
Commit

1. Made hyperparameters, collection, top_results and similarity_threshold prompt-specific.

2. Added unit and integration test cases.
3. Fixed test cases.
Nupur Khare committed Jan 17, 2024
1 parent 91775c1 commit 8e40f57
Showing 17 changed files with 715 additions and 633 deletions.
27 changes: 11 additions & 16 deletions kairon/actions/definitions/prompt.py
@@ -120,11 +120,10 @@ async def __get_gpt_params(self, k_faq_action_config: dict, dispatcher: Collecti
system_prompt = None
context_prompt = ''
query_prompt = ''
query_prompt_dict = {}
history_prompt = None
is_query_prompt_enabled = False
similarity_prompt_name = None
similarity_prompt_instructions = None
use_similarity_prompt = False
similarity_prompt = {}
params = {}
num_bot_responses = k_faq_action_config['num_bot_responses']
for prompt in k_faq_action_config['llm_prompts']:
@@ -134,9 +133,11 @@ async def __get_gpt_params(self, k_faq_action_config: dict, dispatcher: Collecti
if prompt['source'] == LlmPromptSource.history.value:
history_prompt = ActionUtility.prepare_bot_responses(tracker, num_bot_responses)
elif prompt['source'] == LlmPromptSource.bot_content.value:
similarity_prompt_name = prompt['name']
similarity_prompt_instructions = prompt['instructions']
use_similarity_prompt = True
similarity_prompt.update({'similarity_prompt_name': prompt['name'],
'similarity_prompt_instructions': prompt['instructions'],
'collection': prompt['collection'],
'use_similarity_prompt': True, 'top_results': prompt.get('top_results'),
'similarity_threshold': prompt.get('similarity_threshold')})
elif prompt['source'] == LlmPromptSource.slot.value:
slot_data = tracker.get_slot(prompt['data'])
context_prompt += f"{prompt['name']}:\n{slot_data}\n"
@@ -159,21 +160,15 @@ async def __get_gpt_params(self, k_faq_action_config: dict, dispatcher: Collecti
if prompt['instructions']:
query_prompt += f"Instructions on how to use {prompt['name']}:\n{prompt['instructions']}\n\n"
is_query_prompt_enabled = True
query_prompt_dict.update({'query_prompt': query_prompt, 'use_query_prompt': is_query_prompt_enabled,
'hyperparameters': prompt.get('hyperparameters', Utility.get_llm_hyperparameters())})

params["top_results"] = k_faq_action_config.get('top_results', 10)
params["similarity_threshold"] = k_faq_action_config.get('similarity_threshold', 0.70)
params["hyperparameters"] = k_faq_action_config.get('hyperparameters', Utility.get_llm_hyperparameters())
params['enable_response_cache'] = k_faq_action_config.get('enable_response_cache', False)
params["system_prompt"] = system_prompt
params["context_prompt"] = context_prompt
params["query_prompt"] = query_prompt
params["use_query_prompt"] = is_query_prompt_enabled
params["query_prompt"] = query_prompt_dict
params["previous_bot_responses"] = history_prompt
params['use_similarity_prompt'] = use_similarity_prompt
params['similarity_prompt_name'] = similarity_prompt_name
params['similarity_prompt_instructions'] = similarity_prompt_instructions
params["similarity_prompt"] = similarity_prompt
params['instructions'] = k_faq_action_config.get('instructions', [])
params['collection'] = k_faq_action_config.get('collection')
return params

@staticmethod
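Taken together, this hunk swaps the flat top_results, similarity_threshold and use_similarity_prompt keys for two nested dicts. A rough sketch of the shape __get_gpt_params now returns, with field names taken from the diff and illustrative prompt text and collection name:

params = {
    "system_prompt": system_prompt,
    "context_prompt": context_prompt,
    "query_prompt": {                                          # previously a flat string
        "query_prompt": "Rephrase the user question.",         # illustrative
        "use_query_prompt": True,
        "hyperparameters": Utility.get_llm_hyperparameters(),  # per-prompt value, else defaults
    },
    "previous_bot_responses": history_prompt,
    "similarity_prompt": {                                     # previously five flat keys
        "similarity_prompt_name": "Similarity Prompt",         # illustrative
        "similarity_prompt_instructions": "Answer based on the context.",
        "collection": "python",                                # illustrative
        "use_similarity_prompt": True,
        "top_results": 10,
        "similarity_threshold": 0.70,
    },
    "instructions": k_faq_action_config.get("instructions", []),
    "collection": k_faq_action_config.get("collection"),
}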
57 changes: 28 additions & 29 deletions kairon/api/models.py
@@ -874,12 +874,40 @@ def check(cls, values):

class LlmPromptRequest(BaseModel):
name: str
top_results: int = 10
similarity_threshold: float = 0.70
hyperparameters: dict = None
data: str = None
collection: str = None
instructions: str = None
type: LlmPromptType
source: LlmPromptSource
is_enabled: bool = True

@validator("similarity_threshold")
def validate_similarity_threshold(cls, v, values, **kwargs):
if not 0.3 <= v <= 1:
raise ValueError("similarity_threshold should be within 0.3 and 1")
return v

@validator("top_results")
def validate_top_results(cls, v, values, **kwargs):
if v > 30:
raise ValueError("top_results should not be greater than 30")
return v

@root_validator
def check(cls, values):
from kairon.shared.utils import Utility

if not values.get('hyperparameters'):
values['hyperparameters'] = {}

for key, value in Utility.get_llm_hyperparameters().items():
if key not in values['hyperparameters']:
values['hyperparameters'][key] = value
return values


class UserQuestionModel(BaseModel):
type: UserMessageType = UserMessageType.from_user_message.value
@@ -891,53 +919,24 @@ class PromptActionConfigRequest(BaseModel):
num_bot_responses: int = 5
failure_message: str = DEFAULT_NLU_FALLBACK_RESPONSE
user_question: UserQuestionModel = UserQuestionModel()
top_results: int = 10
similarity_threshold: float = 0.70
enable_response_cache: bool = False
hyperparameters: dict = None
llm_prompts: List[LlmPromptRequest]
instructions: List[str] = []
collection: str = None
set_slots: List[SetSlotsUsingActionResponse] = []
dispatch_response: bool = True

@validator("similarity_threshold")
def validate_similarity_threshold(cls, v, values, **kwargs):
if not 0.3 <= v <= 1:
raise ValueError("similarity_threshold should be within 0.3 and 1")
return v

@validator("llm_prompts")
def validate_llm_prompts(cls, v, values, **kwargs):
from kairon.shared.utils import Utility

Utility.validate_kairon_faq_llm_prompts([vars(value) for value in v], ValueError)
return v

@validator("top_results")
def validate_top_results(cls, v, values, **kwargs):
if v > 30:
raise ValueError("top_results should not be greater than 30")
return v

@validator("num_bot_responses")
def validate_num_bot_responses(cls, v, values, **kwargs):
if v > 5:
raise ValueError("num_bot_responses should not be greater than 5")
return v

@root_validator
def check(cls, values):
from kairon.shared.utils import Utility

if not values.get('hyperparameters'):
values['hyperparameters'] = {}

for key, value in Utility.get_llm_hyperparameters().items():
if key not in values['hyperparameters']:
values['hyperparameters'][key] = value
return values


class ColumnMetadata(BaseModel):
column_name: str
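With the validators relocated from PromptActionConfigRequest onto LlmPromptRequest, each prompt in a request now carries its own retrieval settings. A hypothetical bot_content prompt fragment under the new schema (prompt name, collection and hyperparameter values are examples):

{
    "name": "Similarity Prompt",
    "type": "user",
    "source": "bot_content",
    "collection": "python",
    "top_results": 10,
    "similarity_threshold": 0.70,
    "hyperparameters": {"temperature": 0.0},
    "is_enabled": true
}

Per the validators above, a top_results greater than 30 or a similarity_threshold outside [0.3, 1] is now rejected per prompt, and the root validator back-fills any hyperparameters not supplied from Utility.get_llm_hyperparameters().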
22 changes: 14 additions & 8 deletions kairon/importer/validator/file_validator.py
@@ -672,15 +672,7 @@ def __validate_prompt_actions(prompt_actions: list):
continue
if action.get('num_bot_responses') and (action['num_bot_responses'] > 5 or not isinstance(action['num_bot_responses'], int)):
data_error.append(f'num_bot_responses should not be greater than 5 and of type int: {action.get("name")}')
if action.get('top_results') and (action['top_results'] > 30 or not isinstance(action['top_results'], int)):
data_error.append(f'top_results should not be greater than 30 and of type int: {action.get("name")}')
if action.get('similarity_threshold'):
if not (0.3 <= action['similarity_threshold'] <= 1) or not (isinstance(action['similarity_threshold'], float) or isinstance(action['similarity_threshold'], int)):
data_error.append(f'similarity_threshold should be within 0.3 and 1 and of type int or float: {action.get("name")}')
llm_prompts_errors = TrainingDataValidator.__validate_llm_prompts(action['llm_prompts'])
if action.get('hyperparameters') is not None:
llm_hyperparameters_errors = TrainingDataValidator.__validate_llm_prompts_hyperparamters(action.get('hyperparameters'))
data_error.extend(llm_hyperparameters_errors)
data_error.extend(llm_prompts_errors)
if action['name'] in actions_present:
data_error.append(f'Duplicate action found: {action["name"]}')
@@ -696,6 +688,18 @@ def __validate_llm_prompts(llm_prompts: dict):
history_prompt_count = 0
bot_content_prompt_count = 0
for prompt in llm_prompts:
if prompt.get('top_results') and (prompt['top_results'] > 30 or not isinstance(prompt['top_results'], int)):
error_list.append(f'top_results should not be greater than 30 and of type int: {prompt.get("name")}')
if prompt.get('similarity_threshold'):
if not (0.3 <= prompt['similarity_threshold'] <= 1) or not (
isinstance(prompt['similarity_threshold'], float) or isinstance(prompt['similarity_threshold'],
int)):
error_list.append(
f'similarity_threshold should be within 0.3 and 1 and of type int or float: {prompt.get("name")}')
if prompt.get('hyperparameters') is not None:
llm_hyperparameters_errors = TrainingDataValidator.__validate_llm_prompts_hyperparamters(
prompt.get('hyperparameters'))
error_list.extend(llm_hyperparameters_errors)
if prompt.get('type') == 'system':
system_prompt_count += 1
elif prompt.get('source') == 'history':
@@ -726,6 +730,8 @@ def __validate_llm_prompts(llm_prompts: dict):
error_list.append('data field in prompts should of type string.')
if not prompt.get('data') and prompt.get('source') == 'static':
error_list.append('data is required for static prompts')
if Utility.check_empty_string(prompt.get('collection')) and prompt.get('source') == 'bot_content':
error_list.append("Collection is required for bot content prompts!")
if system_prompt_count > 1:
error_list.append('Only one system prompt can be present')
if system_prompt_count == 0:
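To exercise the relocated checks, a single prompt dict that should trip all three new per-prompt validations (values are deliberately invalid; the name is hypothetical):

bad_prompt = {
    "name": "Similarity Prompt",
    "type": "user",
    "source": "bot_content",
    "top_results": 40,            # greater than 30
    "similarity_threshold": 0.1,  # outside [0.3, 1]
    "collection": "",             # empty for a bot_content prompt
}
# Expected entries in error_list, matching the strings above:
#   top_results should not be greater than 30 and of type int: Similarity Prompt
#   similarity_threshold should be within 0.3 and 1 and of type int or float: Similarity Prompt
#   Collection is required for bot content prompts!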
32 changes: 17 additions & 15 deletions kairon/shared/actions/data_objects.py
@@ -627,17 +627,32 @@ def clean(self):

class LlmPrompt(EmbeddedDocument):
name = StringField(required=True)
top_results = IntField(default=10)
similarity_threshold = FloatField(default=0.70)
hyperparameters = DictField(default=Utility.get_llm_hyperparameters)
data = StringField()
collection = StringField(default=None)
instructions = StringField()
type = StringField(required=True, choices=[LlmPromptType.user.value, LlmPromptType.system.value, LlmPromptType.query.value])
source = StringField(choices=[LlmPromptSource.static.value, LlmPromptSource.history.value, LlmPromptSource.bot_content.value,
LlmPromptSource.action.value, LlmPromptSource.slot.value],
default=LlmPromptSource.static.value)
is_enabled = BooleanField(default=True)

def clean(self):
for key, value in Utility.get_llm_hyperparameters().items():
if key not in self.hyperparameters:
self.hyperparameters.update({key: value})

def validate(self, clean=True):
if not 0.3 <= self.similarity_threshold <= 1:
raise ValidationError("similarity_threshold should be within 0.3 and 1")
if self.top_results > 30:
raise ValidationError("top_results should not be greater than 30")
if self.type == LlmPromptType.system.value and self.source != LlmPromptSource.static.value:
raise ValidationError("System prompt must have static source!")
dict_data = self.to_mongo().to_dict()
Utility.validate_llm_hyperparameters(dict_data['hyperparameters'], ValidationError)


class UserQuestion(EmbeddedDocument):
@@ -651,43 +666,30 @@ class UserQuestion(EmbeddedDocument):
class PromptAction(Auditlog):
name = StringField(required=True)
num_bot_responses = IntField(default=5)
top_results = IntField(default=10)
similarity_threshold = FloatField(default=0.70)
enable_response_cache = BooleanField(default=False)
failure_message = StringField(default=DEFAULT_NLU_FALLBACK_RESPONSE)
user_question = EmbeddedDocumentField(UserQuestion, default=UserQuestion())
bot = StringField(required=True)
user = StringField(required=True)
timestamp = DateTimeField(default=datetime.utcnow)
hyperparameters = DictField(default=Utility.get_llm_hyperparameters)
llm_prompts = EmbeddedDocumentListField(LlmPrompt, required=True)
instructions = ListField(StringField())
collection = StringField(default=None)
set_slots = EmbeddedDocumentListField(SetSlotsFromResponse)
dispatch_response = BooleanField(default=True)
status = BooleanField(default=True)

meta = {"indexes": [{"fields": ["bot", ("bot", "name", "status")]}]}

def clean(self):
for key, value in Utility.get_llm_hyperparameters().items():
if key not in self.hyperparameters:
self.hyperparameters.update({key: value})

def validate(self, clean=True):
if clean:
self.clean()
if self.num_bot_responses > 5:
raise ValidationError("num_bot_responses should not be greater than 5")
if not 0.3 <= self.similarity_threshold <= 1:
raise ValidationError("similarity_threshold should be within 0.3 and 1")
if self.top_results > 30:
raise ValidationError("top_results should not be greater than 30")
if not self.llm_prompts:
raise ValidationError("llm_prompts are required!")
for prompts in self.llm_prompts:
prompts.validate()
dict_data = self.to_mongo().to_dict()
Utility.validate_kairon_faq_llm_prompts(dict_data['llm_prompts'], ValidationError)
Utility.validate_llm_hyperparameters(dict_data['hyperparameters'], ValidationError)


@auditlogger.log
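Because LlmPrompt now owns clean() and validate(), PromptAction.validate no longer range-checks top_results or similarity_threshold itself and instead iterates over its embedded prompts. A minimal sketch of the per-prompt failure mode, assuming the models above are importable:

from mongoengine import ValidationError

prompt = LlmPrompt(name="Similarity Prompt", type="user", source="bot_content", top_results=40)
try:
    prompt.validate()
except ValidationError as e:
    print(e)  # top_results should not be greater than 30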
2 changes: 1 addition & 1 deletion kairon/shared/cognition/processor.py
@@ -246,7 +246,7 @@ def find_matching_metadata(bot: Text, data: Any, collection: Text = None):

@staticmethod
def validate_collection_name(bot: Text, collection: Text):
prompt_action = list(PromptAction.objects(bot=bot, collection__iexact=collection))
prompt_action = list(PromptAction.objects(bot=bot, llm_prompts__collection__iexact=collection))
database_action = list(DatabaseAction.objects(bot=bot, collection__iexact=collection))
if prompt_action:
raise AppException(f'Cannot remove collection {collection} linked to action "{prompt_action[0].name}"!')
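This one-line change follows directly from the schema move: collection now lives on each embedded LlmPrompt, so the lookup uses mongoengine's double-underscore syntax to traverse into the embedded list. For illustration:

# before: collection was a field on PromptAction itself
PromptAction.objects(bot=bot, collection__iexact=collection)
# after: match actions where any embedded llm_prompt has a matching collection (case-insensitive)
PromptAction.objects(bot=bot, llm_prompts__collection__iexact=collection)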
5 changes: 0 additions & 5 deletions kairon/shared/data/processor.py
@@ -5510,14 +5510,9 @@ def edit_prompt_action(self, prompt_action_id: str, request_data: dict, bot: Tex
action.name = request_data.get("name")
action.failure_message = request_data.get("failure_message")
action.user_question = UserQuestion(**request_data.get("user_question"))
action.top_results = request_data.get("top_results")
action.enable_response_cache = request_data.get("enable_response_cache", False)
action.similarity_threshold = request_data.get("similarity_threshold")
action.num_bot_responses = request_data.get('num_bot_responses', 5)
action.hyperparameters = request_data.get('hyperparameters', Utility.get_llm_hyperparameters())
action.llm_prompts = [LlmPrompt(**prompt) for prompt in request_data.get('llm_prompts', [])]
action.instructions = request_data.get('instructions', [])
action.collection = request_data.get('collection')
action.set_slots = request_data.get('set_slots', [])
action.dispatch_response = request_data.get('dispatch_response', True)
action.timestamp = datetime.utcnow()
24 changes: 15 additions & 9 deletions kairon/shared/llm/gpt3.py
@@ -104,10 +104,15 @@ async def __get_embedding(self, text: Text) -> List[float]:
return result

async def __get_answer(self, query, system_prompt: Text, context: Text, **kwargs):
query_prompt = kwargs.get('query_prompt')
use_query_prompt = kwargs.get('use_query_prompt')
use_query_prompt = False
query_prompt = ''
hyperparameters = Utility.get_llm_hyperparameters()
if kwargs.get('query_prompt', {}):
query_prompt_dict = kwargs.pop('query_prompt')
query_prompt = query_prompt_dict.get('query_prompt', '')
use_query_prompt = query_prompt_dict.get('use_query_prompt')
hyperparameters = query_prompt_dict.get('hyperparameters', Utility.get_llm_hyperparameters())
previous_bot_responses = kwargs.get('previous_bot_responses')
hyperparameters = kwargs.get('hyperparameters', Utility.get_llm_hyperparameters())
instructions = kwargs.get('instructions', [])
instructions = '\n'.join(instructions)

@@ -196,13 +201,14 @@ def logs(self):
return self.__logs

async def __attach_similarity_prompt_if_enabled(self, query_embedding, context_prompt, **kwargs):
use_similarity_prompt = kwargs.pop('use_similarity_prompt')
similarity_prompt_name = kwargs.pop('similarity_prompt_name')
similarity_prompt_instructions = kwargs.pop('similarity_prompt_instructions')
limit = kwargs.pop('top_results', 10)
score_threshold = kwargs.pop('similarity_threshold', 0.70)
similarity_prompt = kwargs.pop('similarity_prompt')
use_similarity_prompt = similarity_prompt.get('use_similarity_prompt')
similarity_prompt_name = similarity_prompt.get('similarity_prompt_name')
similarity_prompt_instructions = similarity_prompt.get('similarity_prompt_instructions')
limit = similarity_prompt.get('top_results', 10)
score_threshold = similarity_prompt.get('similarity_threshold', 0.70)
if use_similarity_prompt:
collection_name = f"{self.bot}_{kwargs.get('collection')}{self.suffix}" if kwargs.get('collection') else f"{self.bot}{self.suffix}"
collection_name = f"{self.bot}_{similarity_prompt.get('collection')}{self.suffix}" if similarity_prompt.get('collection') else f"{self.bot}{self.suffix}"
search_result = await self.__collection_search__(collection_name, vector=query_embedding, limit=limit, score_threshold=score_threshold)

similarity_context = "\n".join([item['payload']['content'] for item in search_result['result']])
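__attach_similarity_prompt_if_enabled now pops a single similarity_prompt dict instead of five flat kwargs, and builds the vector collection name from the prompt's own collection. A sketch of the kwargs it expects (the method is private and name-mangled, so this is illustrative rather than a public call; values are examples):

similarity_prompt = {
    "use_similarity_prompt": True,
    "similarity_prompt_name": "Similarity Prompt",
    "similarity_prompt_instructions": "Answer based on the context.",
    "collection": "python",        # searched as f"{bot}_python{suffix}"
    "top_results": 10,             # defaults to 10 when absent
    "similarity_threshold": 0.70,  # defaults to 0.70 when absent
}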
2 changes: 2 additions & 0 deletions kairon/shared/utils.py
@@ -1781,6 +1781,8 @@ def validate_kairon_faq_llm_prompts(llm_prompts: List, exception_class):
raise exception_class("Name cannot be empty!")
if Utility.check_empty_string(prompt.get('data')) and prompt['source'] == LlmPromptSource.static.value:
raise exception_class("data is required for static prompts!")
if Utility.check_empty_string(prompt.get('collection')) and prompt['source'] == LlmPromptSource.bot_content.value:
raise exception_class("Collection is required for bot content prompts!")
if prompt['type'] == LlmPromptType.query.value and prompt['source'] != LlmPromptSource.static.value:
raise exception_class("Query prompt must have static source!")
if prompt.get('type') == LlmPromptType.system.value:
14 changes: 0 additions & 14 deletions template/use-cases/Hi-Hello-GPT/actions.yml
@@ -13,19 +13,7 @@ jira_action: []
pipedrive_leads_action: []
prompt_action:
- dispatch_response: true
enable_response_cache: false
failure_message: Kindly share more details so I can assist you effectively.
hyperparameters:
frequency_penalty: 0
logit_bias: {}
max_tokens: 300
model: gpt-3.5-turbo
n: 1
presence_penalty: 0
stop: null
stream: false
temperature: 0
top_p: 0
instructions: []
llm_prompts:
- data: 'You are a helpful personal assistant. Answer the question based on the
@@ -47,9 +35,7 @@ prompt_action:
name: kairon_faq_action
num_bot_responses: 5
set_slots: []
similarity_threshold: 0.7
status: true
top_results: 10
slot_set_action: []
two_stage_fallback: []
zendesk_action: []
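On the template side, the action-level hyperparameters, top_results and similarity_threshold blocks are simply removed; under the new schema those settings would sit on individual llm_prompts entries instead. A hypothetical bot_content entry in this YAML (this template's prompts are static and history, so no such entry appears above; the collection name is made up):

  - collection: python
    instructions: Answer the question based on the context below.
    is_enabled: true
    name: Similarity Prompt
    similarity_threshold: 0.7
    source: bot_content
    top_results: 10
    type: user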