Irrelevant citations fix for perplexity (#1610)
himanshugt16 authored Nov 29, 2024
1 parent 48554f8 commit f2f296b
Showing 5 changed files with 214 additions and 4 deletions.
1 change: 1 addition & 0 deletions kairon/actions/definitions/prompt.py
@@ -75,6 +75,7 @@ async def execute(self, dispatcher: CollectingDispatcher, tracker: Tracker, doma
llm_response, time_taken_llm_response = await llm_processor.predict(user_msg,
user=tracker.sender_id,
invocation='prompt_action',
llm_type=llm_type,
**llm_params)
status = "FAILURE" if llm_response.get("is_failure", False) is True else status
exception = llm_response.get("exception")
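For context, a minimal standalone sketch of the pattern this hunk relies on: the action forwards an optional llm_type keyword argument and the processor pops it from kwargs with a default. The DEFAULT_LLM value and the simplified predict signature below are assumptions for illustration, not the real kairon API.

import asyncio

DEFAULT_LLM = "openai"  # assumed default for illustration; the real constant lives in kairon

async def predict(query: str, user: str, **kwargs) -> dict:
    # The processor pops llm_type so downstream calls can branch on it (e.g. perplexity handling).
    llm_type = kwargs.pop("llm_type", DEFAULT_LLM)
    return {"query": query, "user": user, "llm_type": llm_type, "extra": kwargs}

print(asyncio.run(predict("hi", "udit.pandey", invocation="prompt_action", llm_type="perplexity")))
# -> {'query': 'hi', 'user': 'udit.pandey', 'llm_type': 'perplexity', 'extra': {'invocation': 'prompt_action'}}
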
2 changes: 1 addition & 1 deletion kairon/shared/cognition/processor.py
@@ -454,7 +454,7 @@ def validate_data(self, primary_key_col: str, collection_name: str, event_type:
existing_document_map = {
doc["data"].get(primary_key_col): doc
for doc in existing_documents
if doc["data"].get(primary_key_col) is not None # Ensure primary key exists in map
if doc["data"].get(primary_key_col) is not None
}

for row in data:
25 changes: 23 additions & 2 deletions kairon/shared/llm/processor.py
@@ -106,6 +106,7 @@ async def predict(self, query: Text, user, *args, **kwargs) -> Tuple:
start_time = time.time()
embeddings_created = False
invocation = kwargs.pop('invocation', None)
llm_type = kwargs.pop('llm_type', DEFAULT_LLM)
try:
query_embedding = await self.get_embedding(query, user, invocation=invocation)
embeddings_created = True
@@ -114,7 +115,7 @@ async def predict(self, query: Text, user, *args, **kwargs) -> Tuple:
context_prompt = kwargs.pop('context_prompt', DEFAULT_CONTEXT_PROMPT)

context = await self.__attach_similarity_prompt_if_enabled(query_embedding, context_prompt, **kwargs)
answer = await self.__get_answer(query, system_prompt, context, user, invocation=invocation,**kwargs)
answer = await self.__get_answer(query, system_prompt, context, user, invocation=invocation, llm_type=llm_type, **kwargs)
response = {"content": answer}
except Exception as e:
logging.exception(e)
@@ -205,6 +206,7 @@ async def __get_answer(self, query, system_prompt: Text, context: Text, user, **
use_query_prompt = False
query_prompt = ''
invocation = kwargs.pop('invocation')
llm_type = kwargs.get('llm_type')
if kwargs.get('query_prompt', {}):
query_prompt_dict = kwargs.pop('query_prompt')
query_prompt = query_prompt_dict.get('query_prompt', '')
@@ -224,6 +226,7 @@ async def __get_answer(self, query, system_prompt: Text, context: Text, user, **
]
if previous_bot_responses:
messages.extend(previous_bot_responses)
query = self.modify_user_message_for_perplexity(query, llm_type, hyperparameters)
messages.append({"role": "user", "content": f"{context} \n{instructions} \nQ: {query} \nA:"}) if instructions \
else messages.append({"role": "user", "content": f"{context} \nQ: {query} \nA:"})
completion, raw_response = await self.__get_completion(messages=messages,
@@ -396,4 +399,22 @@ def fetch_llm_metadata(bot: str):

metadata[llm_type]['properties']['model']['enum'] = models

return metadata
return metadata

@staticmethod
def modify_user_message_for_perplexity(user_msg: str, llm_type: str, hyperparameters: Dict) -> str:
"""
Modify the user message if the LLM type is 'perplexity' and a search domain filter is provided.
:param user_msg: The original user message.
:param llm_type: The LLM type to check if it's 'perplexity'.
:param hyperparameters: LLM hyperparameters
:return: Modified user message.
"""
if llm_type == 'perplexity':
search_domain_filter = hyperparameters.get('search_domain_filter')
if search_domain_filter:
search_domain_filter_str = "|".join(
[domain.strip() for domain in search_domain_filter if domain.strip()]
)
user_msg = f"{user_msg} inurl:{search_domain_filter_str}"
return user_msg
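
For illustration only (not part of this commit), a self-contained copy of the helper above together with a usage check; the domains mirror the ones used in the new integration test below.

from typing import Dict

def modify_user_message_for_perplexity(user_msg: str, llm_type: str, hyperparameters: Dict) -> str:
    # Standalone copy of the method above, for illustration.
    if llm_type == 'perplexity':
        search_domain_filter = hyperparameters.get('search_domain_filter')
        if search_domain_filter:
            search_domain_filter_str = "|".join(
                domain.strip() for domain in search_domain_filter if domain.strip()
            )
            user_msg = f"{user_msg} inurl:{search_domain_filter_str}"
    return user_msg

assert modify_user_message_for_perplexity(
    "What kind of language is python?", "perplexity",
    {"search_domain_filter": ["domain1.com", "domain2.com"]}
) == "What kind of language is python? inurl:domain1.com|domain2.com"

# Any other llm_type, or an empty filter, leaves the message unchanged.
assert modify_user_message_for_perplexity("hello", "openai", {}) == "hello"
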
116 changes: 116 additions & 0 deletions tests/integration_test/action_service_test.py
@@ -2,6 +2,7 @@
import datetime
import os
from urllib.parse import urlencode, urljoin
import urllib

import litellm
from unittest import mock
@@ -11751,6 +11752,121 @@ def test_prompt_action_response_action_with_prompt_question_from_slot(mock_embed
'response': None, 'image': None, 'attachment': None}
]

@mock.patch.object(litellm, "aembedding", autospec=True)
@mock.patch.object(ActionUtility, 'execute_request_async', autospec=True)
def test_prompt_action_response_action_with_prompt_question_from_slot_perplexity(mock_execute_request_async, mock_embedding, aioresponses):
from uuid6 import uuid7
llm_type = "perplexity"
action_name = "test_prompt_action_response_action_with_prompt_question_from_slot"
bot = "5f50fd0a56b69s8ca10d35d2l"
user = "udit.pandey"
value = "keyvalue"
user_msg = "What kind of language is python?"
bot_content = "Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation. Python is dynamically typed and garbage-collected."
generated_text = "Python is dynamically typed, garbage-collected, high level, general purpose programming."
llm_prompts = [
{'name': 'System Prompt',
'data': 'You are a personal assistant. Answer question based on the context below.',
'type': 'system', 'source': 'static', 'is_enabled': True},
{'name': 'History Prompt', 'type': 'user', 'source': 'history', 'is_enabled': True},
{'name': 'Query Prompt', 'data': "What kind of language is python?", 'instructions': 'Rephrase the query.',
'type': 'query', 'source': 'static', 'is_enabled': False},
{'name': 'Similarity Prompt',
'instructions': 'Answer question based on the context above, if answer is not in the context go check previous logs.',
'type': 'user', 'source': 'bot_content', 'data': 'python',
'hyperparameters': {"top_results": 10, "similarity_threshold": 0.70},
'is_enabled': True}
]
mock_execute_request_async.return_value = (
{
'formatted_response': 'Python is dynamically typed, garbage-collected, high level, general purpose programming.',
'response': 'Python is dynamically typed, garbage-collected, high level, general purpose programming.'},
200,
mock.ANY,
mock.ANY
)
embedding = list(np.random.random(OPENAI_EMBEDDING_OUTPUT))
mock_embedding.return_value = litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]})
expected_body = {'messages': [
{'role': 'system', 'content': 'You are a personal assistant. Answer question based on the context below.\n'},
{'role': 'user', 'content': 'hello'}, {'role': 'assistant', 'content': 'how are you'}, {'role': 'user',
'content': "\nInstructions on how to use Similarity Prompt:\n['Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation. Python is dynamically typed and garbage-collected.']\nAnswer question based on the context above, if answer is not in the context go check previous logs.\n \nQ: What kind of language is python? \nA:"}],
'metadata': {'user': 'udit.pandey', 'bot': '5f50fd0a56b698ca10d35d2l', 'invocation': 'prompt_action'},
'api_key': 'keyvalue',
'num_retries': 3, 'temperature': 0.0, 'max_tokens': 300, 'model': 'gpt-4o-mini', 'top_p': 0.0, 'n': 1,
'stop': None, 'presence_penalty': 0.0, 'frequency_penalty': 0.0, 'logit_bias': {}}
aioresponses.add(
url=urljoin(Utility.environment['llm']['url'],
f"/{bot}/completion/{llm_type}"),
method="POST",
status=200,
payload={'formatted_response': generated_text, 'response': generated_text},
body=json.dumps(expected_body)
)
aioresponses.add(
url=f"{Utility.environment['vector']['db']}/collections/{bot}_python_faq_embd/points/search",
body={'vector': embedding},
payload={'result': [{'id': uuid7().__str__(), 'score': 0.80, 'payload': {'content': bot_content}}]},
method="POST",
status=200
)
hyperparameters = Utility.get_llm_hyperparameters("perplexity")
hyperparameters['search_domain_filter'] = ["domain1.com", "domain2.com"]
Actions(name=action_name, type=ActionType.prompt_action.value, bot=bot, user=user).save()
BotSettings(llm_settings=LLMSettings(enable_faq=True), bot=bot, user=user).save()
PromptAction(name=action_name, bot=bot, user=user, num_bot_responses=2, llm_prompts=llm_prompts, llm_type="perplexity", hyperparameters=hyperparameters,
user_question=UserQuestion(type="from_slot", value="prompt_question")).save()
llm_secret = LLMSecret(
llm_type=llm_type,
api_key=value,
models=["perplexity/llama-3.1-sonar-small-128k-online", "perplexity/llama-3.1-sonar-large-128k-online", "perplexity/llama-3.1-sonar-huge-128k-online"],
bot=bot,
user=user
)
llm_secret.save()
llm_secret = LLMSecret(
llm_type="openai",
api_key="api_key",
models=["gpt-3.5-turbo", "gpt-4o-mini"],
bot=bot,
user=user
)
llm_secret.save()
request_object = json.load(open("tests/testing_data/actions/action-request.json"))
request_object["tracker"]["slots"] = {"bot": bot, "prompt_question": user_msg}
request_object["next_action"] = action_name
request_object["tracker"]["sender_id"] = user
request_object['tracker']['events'] = [{"event": "user", 'text': 'hello',
"data": {"elements": '', "quick_replies": '', "buttons": '',
"attachment": '', "image": '', "custom": ''}},
{'event': 'bot', "text": "how are you",
"data": {"elements": '', "quick_replies": '', "buttons": '',
"attachment": '', "image": '', "custom": ''}}]
response = client.post("/webhook", json=request_object)
response_json = response.json()
mock_execute_request_async.assert_called_once_with(
http_url=f"{Utility.environment['llm']['url']}/{urllib.parse.quote(bot)}/completion/{llm_type}",
request_method="POST",
request_body={
'messages': [{'role': 'system', 'content': 'You are a personal assistant. Answer question based on the context below.\n'},
{'role': 'user', 'content': 'hello'},
{'role': 'assistant', 'content': 'how are you'},
{'role': 'user', 'content': "\nInstructions on how to use Similarity Prompt:\n['Python is a high-level, general-purpose programming language. Its design philosophy emphasizes code readability with the use of significant indentation. Python is dynamically typed and garbage-collected.']\nAnswer question based on the context above, if answer is not in the context go check previous logs.\n \nQ: What kind of language is python? inurl:domain1.com|domain2.com \nA:"}],
'hyperparameters': hyperparameters,
'user': user,
'invocation': "prompt_action"
},
timeout=Utility.environment['llm'].get('request_timeout', 30)
)
called_args = mock_execute_request_async.call_args
user_message = called_args.kwargs['request_body']['messages'][-1]['content']
assert "inurl:domain1.com|domain2.com" in user_message
assert response_json['events'] == [
{'event': 'slot', 'timestamp': None, 'name': 'kairon_action_response', 'value': generated_text}]
assert response_json['responses'] == [
{'text': generated_text, 'buttons': [], 'elements': [], 'custom': {}, 'template': None,
'response': None, 'image': None, 'attachment': None}
]

@mock.patch.object(litellm, "aembedding", autospec=True)
def test_prompt_action_response_action_with_prompt_question_from_slot_different_embedding_completion(mock_embedding, aioresponses):
74 changes: 73 additions & 1 deletion tests/testing_data/llm_metadata.yml
@@ -79,4 +79,76 @@ anthropic:
type: string
default: "claude-3-haiku-20240307"
enum: ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620", "claude-3-sonnet-20240229", "claude-3-haiku-20240307"]
description: "The model hyperparameter is the ID of the Anthropic model or a custom model that you have trained or fine-tuned."
description: "The model hyperparameter is the ID of the Anthropic model or a custom model that you have trained or fine-tuned."
perplexity:
$schema: "https://json-schema.org/draft/2020-12/schema"
type: "object"
description: "Perplexity AI Models for Prompt"
properties:
temperature:
type: "number"
default: 0.2
minimum: 0.0
maximum: 2.0
description: "The temperature hyperparameter controls the creativity or randomness of the generated responses."
max_tokens:
anyOf:
- type: "integer"
minimum: 5
maximum: 127072
- type: "null"
type:
- "integer"
- "null"
default: null
description: "The max_tokens hyperparameter limits the length of generated responses in chat completion"
model:
type: "string"
default: "perplexity/llama-3.1-sonar-small-128k-online"
enum: ["perplexity/llama-3.1-sonar-small-128k-online", "perplexity/llama-3.1-sonar-large-128k-online", "perplexity/llama-3.1-sonar-huge-128k-online"]
search_domain_filter:
anyOf:
- type: "array"
maxItems: 3
items:
type: "string"
- type: "null"
type:
- "array"
- "null"
default: null
description: "The search domain filter hyperparameter is used to specify list of domain to be used by online models."
search_recency_filter:
anyOf:
- type: "string"
enum: ["month", "week", "day", "hour"]
- type: "null"
type:
- "string"
- "null"
default: null
description: "return results from specified time interval"
top_p:
type: "number"
default: 0.9
minimum: 0.0
maximum: 1.0
description: "The top_p hyperparameter is a value that controls the diversity of the generated responses."
top_k:
type: "integer"
default: 0
minimum: 0
maximum: 2048
description: "The top_k hyperparameter controls the number of token to keep for top_k filtering"
presence_penalty:
type: "number"
default: 0.0
minimum: -2.0
maximum: 2.0
description: "The presence_penalty hyperparameter penalizes the model for generating words that are not present in the context or input prompt."
frequency_penalty:
type: "number"
default: 0.0
minimum: -2.0
maximum: 2.0
description: "The frequency_penalty hyperparameter penalizes the model for generating words that have already been generated in the current response."
