From e7aac7e72523176ce9af2f528d57b5c683bc1f6f Mon Sep 17 00:00:00 2001 From: spandan_mondal Date: Fri, 20 Dec 2024 15:22:25 +0530 Subject: [PATCH 1/3] add missing llm_type based on model in hyperparameters --- kairon/shared/data/data_validation.py | 30 ++++++- .../data_processor/action_serializer_test.py | 83 ++++++++++++++++++- 2 files changed, 108 insertions(+), 5 deletions(-) diff --git a/kairon/shared/data/data_validation.py b/kairon/shared/data/data_validation.py index 0365e852d..e8f711581 100644 --- a/kairon/shared/data/data_validation.py +++ b/kairon/shared/data/data_validation.py @@ -9,6 +9,22 @@ class DataValidation: + model_llm_type_map = None + @staticmethod + def get_model_llm_type_map() -> dict[str, str]: + if DataValidation.model_llm_type_map: + return DataValidation.model_llm_type_map + else: + metadata = Utility.load_yaml(Utility.llm_metadata_file_path) + DataValidation.model_llm_type_map = {} + for llm_type in metadata: + models = metadata[llm_type]['properties']['model']['enum'] + for model in models: + DataValidation.model_llm_type_map[model] = llm_type + + return DataValidation.model_llm_type_map + + @staticmethod def validate_http_action(bot: str, data: dict): action_param_types = {param.value for param in ActionParameterType} @@ -73,12 +89,18 @@ def validate_prompt_action(bot: str, data: dict): data_error.append( f'num_bot_responses should not be greater than 5 and of type int: {data.get("name")}') llm_prompts_errors = DataValidation.validate_llm_prompts(data['llm_prompts']) - if data.get('hyperparameters'): + data_error.extend(llm_prompts_errors) + if hyperparameters := data.get('hyperparameters'): + llm_type = data.get("llm_type") + if not llm_type: + + if model := hyperparameters.get("model"): + data["llm_type"] = DataValidation.get_model_llm_type_map().get(model) + else: + data_error.append("model is required in hyperparameters!") llm_hyperparameters_errors = DataValidation.validate_llm_prompts_hyperparameters( - data.get('hyperparameters'), data.get("llm_type", "openai"), bot) + hyperparameters, data.get("llm_type", "openai"), bot) data_error.extend(llm_hyperparameters_errors) - data_error.extend(llm_prompts_errors) - return data_error @staticmethod diff --git a/tests/unit_test/data_processor/action_serializer_test.py b/tests/unit_test/data_processor/action_serializer_test.py index 1fc0934cb..ea1a0a3c5 100644 --- a/tests/unit_test/data_processor/action_serializer_test.py +++ b/tests/unit_test/data_processor/action_serializer_test.py @@ -1,5 +1,7 @@ import os +from unittest.mock import patch +from deepdiff import DeepDiff from mongoengine import connect from kairon import Utility @@ -417,7 +419,8 @@ def test_validate_prompt_action(): ], "hyperparameters": { "similarity_threshold": 0.5, - "top_results": 5 + "top_results": 5, + "model": "gpt-3.5-turbo", }, "llm_type": "openai" } @@ -716,3 +719,81 @@ def test_action_save_collection_data_list_unknown_data(): ActionSerializer.save_collection_data_list('unknown1', bot, user, [{'data1': 'value1'}]) +def test_prompt_action_validation_missing_model(): + bot = "my_test_bot" + data = { + "num_bot_responses": 3, + "llm_prompts": [ + { + "type": "system", + "source": "static", + "data": "Hello, World!", + "name": "Prompt1", + "hyperparameters": { + "similarity_threshold": 0.5, + "top_results": 5 + } + } + ], + "hyperparameters": { + "similarity_threshold": 0.5, + "top_results": 5, + }, + } + errors = DataValidation.validate_prompt_action(bot, data) + assert errors == ['model is required in hyperparameters!'] + + + + +def test_get_model_llm_type_map(): + result = DataValidation.get_model_llm_type_map() + print(result) + expected = {'gpt-3.5-turbo': 'openai', + 'gpt-4o-mini': 'openai', + 'claude-3-opus-20240229': 'anthropic', + 'claude-3-5-sonnet-20240620': 'anthropic', + 'claude-3-sonnet-20240229': 'anthropic', + 'claude-3-haiku-20240307': 'anthropic', + 'gemini/gemini-1.5-flash': 'gemini', + 'gemini/gemini-pro': 'gemini', + 'perplexity/llama-3.1-sonar-small-128k-online': 'perplexity', + 'perplexity/llama-3.1-sonar-large-128k-online': 'perplexity', + 'perplexity/llama-3.1-sonar-huge-128k-online': 'perplexity'} + + assert not DeepDiff(result, expected, ignore_order=True) + + with patch('kairon.shared.utils.Utility.load_yaml') as yml_load: + DataValidation.get_model_llm_type_map() + yml_load.assert_not_called() + DataValidation.model_llm_type_map = None + DataValidation.get_model_llm_type_map() + yml_load.assert_called_once_with(Utility.llm_metadata_file_path) + + + + +def test_add_llm_type_based_on_model(): + bot = "my_test_bot" + data = { + "num_bot_responses": 3, + "llm_prompts": [ + { + "type": "system", + "source": "static", + "data": "Hello, World!", + "name": "Prompt1", + "hyperparameters": { + "similarity_threshold": 0.5, + "top_results": 5 + } + } + ], + "hyperparameters": { + "similarity_threshold": 0.5, + "top_results": 5, + "model": "gpt-4o-mini", + }, + } + assert not DataValidation.validate_prompt_action(bot, data) + assert data['llm_type'] == 'openai' From ab7deae6fb7a5d9e6f3f2c92f49078dcc9c73ef2 Mon Sep 17 00:00:00 2001 From: spandan_mondal Date: Fri, 20 Dec 2024 15:36:24 +0530 Subject: [PATCH 2/3] add missing llm_type based on model in hyperparameters --- kairon/shared/data/data_validation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/kairon/shared/data/data_validation.py b/kairon/shared/data/data_validation.py index e8f711581..be8d88eef 100644 --- a/kairon/shared/data/data_validation.py +++ b/kairon/shared/data/data_validation.py @@ -91,9 +91,7 @@ def validate_prompt_action(bot: str, data: dict): llm_prompts_errors = DataValidation.validate_llm_prompts(data['llm_prompts']) data_error.extend(llm_prompts_errors) if hyperparameters := data.get('hyperparameters'): - llm_type = data.get("llm_type") - if not llm_type: - + if not data.get("llm_type"): if model := hyperparameters.get("model"): data["llm_type"] = DataValidation.get_model_llm_type_map().get(model) else: From 36bb423c4fc94c6229b8d4da60bd18c4996aea07 Mon Sep 17 00:00:00 2001 From: spandan_mondal Date: Mon, 23 Dec 2024 10:21:06 +0530 Subject: [PATCH 3/3] replaced openai with DEFAULT_LLM --- kairon/shared/data/data_validation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kairon/shared/data/data_validation.py b/kairon/shared/data/data_validation.py index be8d88eef..882dffe8a 100644 --- a/kairon/shared/data/data_validation.py +++ b/kairon/shared/data/data_validation.py @@ -6,6 +6,7 @@ import ast from kairon.shared.callback.data_objects import encrypt_secret +from kairon.shared.data.constant import DEFAULT_LLM class DataValidation: @@ -97,7 +98,7 @@ def validate_prompt_action(bot: str, data: dict): else: data_error.append("model is required in hyperparameters!") llm_hyperparameters_errors = DataValidation.validate_llm_prompts_hyperparameters( - hyperparameters, data.get("llm_type", "openai"), bot) + hyperparameters, data.get("llm_type", DEFAULT_LLM), bot) data_error.extend(llm_hyperparameters_errors) return data_error