diff --git a/kairon/shared/data/data_validation.py b/kairon/shared/data/data_validation.py
index 0365e852d..882dffe8a 100644
--- a/kairon/shared/data/data_validation.py
+++ b/kairon/shared/data/data_validation.py
@@ -6,9 +6,26 @@ import ast
 
 from kairon.shared.callback.data_objects import encrypt_secret
+from kairon.shared.data.constant import DEFAULT_LLM
 
 
 class DataValidation:
+    model_llm_type_map = None
+
+    @staticmethod
+    def get_model_llm_type_map() -> dict[str, str]:
+        if DataValidation.model_llm_type_map:
+            return DataValidation.model_llm_type_map
+        else:
+            metadata = Utility.load_yaml(Utility.llm_metadata_file_path)
+            DataValidation.model_llm_type_map = {}
+            for llm_type in metadata:
+                models = metadata[llm_type]['properties']['model']['enum']
+                for model in models:
+                    DataValidation.model_llm_type_map[model] = llm_type
+
+            return DataValidation.model_llm_type_map
+
 
     @staticmethod
     def validate_http_action(bot: str, data: dict):
         action_param_types = {param.value for param in ActionParameterType}
@@ -73,12 +90,16 @@ def validate_prompt_action(bot: str, data: dict):
             data_error.append(
                 f'num_bot_responses should not be greater than 5 and of type int: {data.get("name")}')
         llm_prompts_errors = DataValidation.validate_llm_prompts(data['llm_prompts'])
-        if data.get('hyperparameters'):
+        data_error.extend(llm_prompts_errors)
+        if hyperparameters := data.get('hyperparameters'):
+            if not data.get("llm_type"):
+                if model := hyperparameters.get("model"):
+                    data["llm_type"] = DataValidation.get_model_llm_type_map().get(model)
+                else:
+                    data_error.append("model is required in hyperparameters!")
             llm_hyperparameters_errors = DataValidation.validate_llm_prompts_hyperparameters(
-                data.get('hyperparameters'), data.get("llm_type", "openai"), bot)
+                hyperparameters, data.get("llm_type", DEFAULT_LLM), bot)
             data_error.extend(llm_hyperparameters_errors)
-        data_error.extend(llm_prompts_errors)
-
         return data_error
 
     @staticmethod
diff --git a/tests/unit_test/data_processor/action_serializer_test.py b/tests/unit_test/data_processor/action_serializer_test.py
index 1fc0934cb..ea1a0a3c5 100644
--- a/tests/unit_test/data_processor/action_serializer_test.py
+++ b/tests/unit_test/data_processor/action_serializer_test.py
@@ -1,5 +1,7 @@
 import os
+from unittest.mock import patch
 
+from deepdiff import DeepDiff
 from mongoengine import connect
 
 from kairon import Utility
@@ -417,7 +419,8 @@ def test_validate_prompt_action():
         ],
         "hyperparameters": {
             "similarity_threshold": 0.5,
-            "top_results": 5
+            "top_results": 5,
+            "model": "gpt-3.5-turbo",
         },
         "llm_type": "openai"
     }
@@ -716,3 +719,81 @@ def test_action_save_collection_data_list_unknown_data():
         ActionSerializer.save_collection_data_list('unknown1', bot, user, [{'data1': 'value1'}])
 
 
+def test_prompt_action_validation_missing_model():
+    bot = "my_test_bot"
+    data = {
+        "num_bot_responses": 3,
+        "llm_prompts": [
+            {
+                "type": "system",
+                "source": "static",
+                "data": "Hello, World!",
+                "name": "Prompt1",
+                "hyperparameters": {
+                    "similarity_threshold": 0.5,
+                    "top_results": 5
+                }
+            }
+        ],
+        "hyperparameters": {
+            "similarity_threshold": 0.5,
+            "top_results": 5,
+        },
+    }
+    errors = DataValidation.validate_prompt_action(bot, data)
+    assert errors == ['model is required in hyperparameters!']
+
+
+def test_get_model_llm_type_map():
+    result = DataValidation.get_model_llm_type_map()
+    print(result)
+    expected = {'gpt-3.5-turbo': 'openai',
+                'gpt-4o-mini': 'openai',
+                'claude-3-opus-20240229': 'anthropic',
+                'claude-3-5-sonnet-20240620': 'anthropic',
+                'claude-3-sonnet-20240229': 'anthropic',
+                'claude-3-haiku-20240307': 'anthropic',
+                'gemini/gemini-1.5-flash': 'gemini',
+                'gemini/gemini-pro': 'gemini',
+                'perplexity/llama-3.1-sonar-small-128k-online': 'perplexity',
+                'perplexity/llama-3.1-sonar-large-128k-online': 'perplexity',
+                'perplexity/llama-3.1-sonar-huge-128k-online': 'perplexity'}
+
+    assert not DeepDiff(result, expected, ignore_order=True)
+
+    with patch('kairon.shared.utils.Utility.load_yaml') as yml_load:
+        DataValidation.get_model_llm_type_map()
+        yml_load.assert_not_called()
+        DataValidation.model_llm_type_map = None
+        DataValidation.get_model_llm_type_map()
+        yml_load.assert_called_once_with(Utility.llm_metadata_file_path)
+
+
+def test_add_llm_type_based_on_model():
+    bot = "my_test_bot"
+    data = {
+        "num_bot_responses": 3,
+        "llm_prompts": [
+            {
+                "type": "system",
+                "source": "static",
+                "data": "Hello, World!",
+                "name": "Prompt1",
+                "hyperparameters": {
+                    "similarity_threshold": 0.5,
+                    "top_results": 5
+                }
+            }
+        ],
+        "hyperparameters": {
+            "similarity_threshold": 0.5,
+            "top_results": 5,
+            "model": "gpt-4o-mini",
+        },
+    }
+    assert not DataValidation.validate_prompt_action(bot, data)
+    assert data['llm_type'] == 'openai'
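
Reviewer note: the new `get_model_llm_type_map` is a lazy class-level cache that inverts the llm_metadata YAML (llm_type -> allowed models) into a model -> llm_type lookup built once per process. A minimal standalone sketch of that pattern, with a hypothetical `llm_metadata` dict standing in for the file loaded from `Utility.llm_metadata_file_path`:

```python
# Hypothetical stand-in for the YAML metadata; only the shape matters here.
llm_metadata = {
    "openai": {"properties": {"model": {"enum": ["gpt-3.5-turbo", "gpt-4o-mini"]}}},
    "gemini": {"properties": {"model": {"enum": ["gemini/gemini-pro"]}}},
}


class ModelMap:
    _cache: dict[str, str] | None = None  # filled on first call, reused afterwards

    @classmethod
    def get(cls) -> dict[str, str]:
        if cls._cache is None:
            # Invert llm_type -> [models] into model -> llm_type.
            cls._cache = {
                model: llm_type
                for llm_type, meta in llm_metadata.items()
                for model in meta["properties"]["model"]["enum"]
            }
        return cls._cache


assert ModelMap.get() == {
    "gpt-3.5-turbo": "openai",
    "gpt-4o-mini": "openai",
    "gemini/gemini-pro": "gemini",
}
assert ModelMap.get() is ModelMap.get()  # same dict object: built only once
```

One caveat worth flagging in review: the patch guards the cache with a truthiness check (`if DataValidation.model_llm_type_map:`) rather than `is None`, so an empty map would be rebuilt from the YAML on every call.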
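For the behavioral change in `validate_prompt_action`, a condensed, hypothetical sketch of the new branch showing its three outcomes: an explicit `llm_type` is kept, a missing `llm_type` is inferred from `hyperparameters.model`, and a missing model is now rejected instead of silently defaulting. `DEFAULT_LLM` and the two-entry model map below are stand-ins, not the real constants:

```python
DEFAULT_LLM = "openai"  # stand-in for kairon.shared.data.constant.DEFAULT_LLM
MODEL_MAP = {"gpt-4o-mini": "openai", "gemini/gemini-pro": "gemini"}  # hypothetical


def resolve_llm_type(data: dict) -> list[str]:
    errors: list[str] = []
    if hyperparameters := data.get("hyperparameters"):
        if not data.get("llm_type"):
            if model := hyperparameters.get("model"):
                data["llm_type"] = MODEL_MAP.get(model)  # infer from model name
            else:
                errors.append("model is required in hyperparameters!")
        # downstream validation then falls back to data.get("llm_type", DEFAULT_LLM)
    return errors


explicit = {"llm_type": "gemini", "hyperparameters": {"model": "gemini/gemini-pro"}}
inferred = {"hyperparameters": {"model": "gpt-4o-mini"}}
missing = {"hyperparameters": {"top_results": 5}}

assert resolve_llm_type(explicit) == [] and explicit["llm_type"] == "gemini"
assert resolve_llm_type(inferred) == [] and inferred["llm_type"] == "openai"
assert resolve_llm_type(missing) == ["model is required in hyperparameters!"]
```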