Skip to content

Commit

Permalink
add missing llm_type based on model in hyperparameters (#1675)
Browse files Browse the repository at this point in the history
* add missing llm_type based on model in hyperparameters

* add missing llm_type based on model in hyperparameters

* replaced openai with DEFAULT_LLM

---------

Co-authored-by: spandan_mondal <[email protected]>
  • Loading branch information
hasinaxp and spandan_mondal authored Dec 23, 2024
1 parent c529677 commit 6dc88e8
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 5 deletions.
29 changes: 25 additions & 4 deletions kairon/shared/data/data_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,26 @@
import ast

from kairon.shared.callback.data_objects import encrypt_secret
from kairon.shared.data.constant import DEFAULT_LLM


class DataValidation:
model_llm_type_map = None
@staticmethod
def get_model_llm_type_map() -> dict[str, str]:
if DataValidation.model_llm_type_map:
return DataValidation.model_llm_type_map
else:
metadata = Utility.load_yaml(Utility.llm_metadata_file_path)
DataValidation.model_llm_type_map = {}
for llm_type in metadata:
models = metadata[llm_type]['properties']['model']['enum']
for model in models:
DataValidation.model_llm_type_map[model] = llm_type

return DataValidation.model_llm_type_map


@staticmethod
def validate_http_action(bot: str, data: dict):
action_param_types = {param.value for param in ActionParameterType}
Expand Down Expand Up @@ -73,12 +90,16 @@ def validate_prompt_action(bot: str, data: dict):
data_error.append(
f'num_bot_responses should not be greater than 5 and of type int: {data.get("name")}')
llm_prompts_errors = DataValidation.validate_llm_prompts(data['llm_prompts'])
if data.get('hyperparameters'):
data_error.extend(llm_prompts_errors)
if hyperparameters := data.get('hyperparameters'):
if not data.get("llm_type"):
if model := hyperparameters.get("model"):
data["llm_type"] = DataValidation.get_model_llm_type_map().get(model)
else:
data_error.append("model is required in hyperparameters!")
llm_hyperparameters_errors = DataValidation.validate_llm_prompts_hyperparameters(
data.get('hyperparameters'), data.get("llm_type", "openai"), bot)
hyperparameters, data.get("llm_type", DEFAULT_LLM), bot)
data_error.extend(llm_hyperparameters_errors)
data_error.extend(llm_prompts_errors)

return data_error

@staticmethod
Expand Down
83 changes: 82 additions & 1 deletion tests/unit_test/data_processor/action_serializer_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from unittest.mock import patch

from deepdiff import DeepDiff
from mongoengine import connect

from kairon import Utility
Expand Down Expand Up @@ -417,7 +419,8 @@ def test_validate_prompt_action():
],
"hyperparameters": {
"similarity_threshold": 0.5,
"top_results": 5
"top_results": 5,
"model": "gpt-3.5-turbo",
},
"llm_type": "openai"
}
Expand Down Expand Up @@ -716,3 +719,81 @@ def test_action_save_collection_data_list_unknown_data():
ActionSerializer.save_collection_data_list('unknown1', bot, user, [{'data1': 'value1'}])


def test_prompt_action_validation_missing_model():
bot = "my_test_bot"
data = {
"num_bot_responses": 3,
"llm_prompts": [
{
"type": "system",
"source": "static",
"data": "Hello, World!",
"name": "Prompt1",
"hyperparameters": {
"similarity_threshold": 0.5,
"top_results": 5
}
}
],
"hyperparameters": {
"similarity_threshold": 0.5,
"top_results": 5,
},
}
errors = DataValidation.validate_prompt_action(bot, data)
assert errors == ['model is required in hyperparameters!']




def test_get_model_llm_type_map():
result = DataValidation.get_model_llm_type_map()
print(result)
expected = {'gpt-3.5-turbo': 'openai',
'gpt-4o-mini': 'openai',
'claude-3-opus-20240229': 'anthropic',
'claude-3-5-sonnet-20240620': 'anthropic',
'claude-3-sonnet-20240229': 'anthropic',
'claude-3-haiku-20240307': 'anthropic',
'gemini/gemini-1.5-flash': 'gemini',
'gemini/gemini-pro': 'gemini',
'perplexity/llama-3.1-sonar-small-128k-online': 'perplexity',
'perplexity/llama-3.1-sonar-large-128k-online': 'perplexity',
'perplexity/llama-3.1-sonar-huge-128k-online': 'perplexity'}

assert not DeepDiff(result, expected, ignore_order=True)

with patch('kairon.shared.utils.Utility.load_yaml') as yml_load:
DataValidation.get_model_llm_type_map()
yml_load.assert_not_called()
DataValidation.model_llm_type_map = None
DataValidation.get_model_llm_type_map()
yml_load.assert_called_once_with(Utility.llm_metadata_file_path)




def test_add_llm_type_based_on_model():
bot = "my_test_bot"
data = {
"num_bot_responses": 3,
"llm_prompts": [
{
"type": "system",
"source": "static",
"data": "Hello, World!",
"name": "Prompt1",
"hyperparameters": {
"similarity_threshold": 0.5,
"top_results": 5
}
}
],
"hyperparameters": {
"similarity_threshold": 0.5,
"top_results": 5,
"model": "gpt-4o-mini",
},
}
assert not DataValidation.validate_prompt_action(bot, data)
assert data['llm_type'] == 'openai'

0 comments on commit 6dc88e8

Please sign in to comment.