Model training time optimization
himanshugt16 committed Oct 28, 2024
1 parent 9a7e824 commit 5c8eb14
Showing 4 changed files with 100 additions and 56 deletions.
109 changes: 76 additions & 33 deletions kairon/shared/llm/processor.py
@@ -1,6 +1,6 @@
 import time
 from secrets import randbelow, choice
-from typing import Text, Dict, List, Tuple
+from typing import Text, Dict, List, Tuple, Union
 from urllib.parse import urljoin
 
 import litellm
@@ -54,27 +54,49 @@ async def train(self, user, *args, **kwargs) -> Dict:
         await self.__delete_collections()
         count = 0
         processor = CognitionDataProcessor()
-        collection_groups = list(CognitionData.objects.aggregate([
-            {'$match': {'bot': self.bot}},
-            {'$group': {'_id': "$collection", 'content': {'$push': "$$ROOT"}}},
-            {'$project': {'collection': "$_id", 'content': 1, '_id': 0}}
-        ]))
-        for collections in collection_groups:
-            collection = f"{self.bot}_{collections['collection']}{self.suffix}" if collections[
-                'collection'] else f"{self.bot}{self.suffix}"
+        batch_size = 100
+
+        collections_data = CognitionData.objects(bot=self.bot)
+        collection_groups = {}
+        for content in collections_data:
+            content_dict = content.to_mongo()
+            collection_name = content_dict.get('collection') or ""
+            if collection_name not in collection_groups:
+                collection_groups[collection_name] = []
+            collection_groups[collection_name].append(content_dict)
+
+        for collection_name, contents in collection_groups.items():
+            collection = f"{self.bot}_{collection_name}{self.suffix}" if collection_name else f"{self.bot}{self.suffix}"
             await self.__create_collection__(collection)
-            for content in tqdm(collections['content'], desc="Training FAQ"):
-                if content['content_type'] == CognitionDataType.json.value:
-                    metadata = processor.find_matching_metadata(self.bot, content['data'], content.get('collection'))
-                    search_payload, embedding_payload = Utility.retrieve_search_payload_and_embedding_payload(
-                        content['data'], metadata)
-                else:
-                    search_payload, embedding_payload = {'content': content["data"]}, content["data"]
-                embeddings = await self.get_embedding(embedding_payload, user, invocation=invocation)
-                points = [{'id': content['vector_id'], 'vector': embeddings, 'payload': search_payload}]
+
+            for i in tqdm(range(0, len(contents), batch_size), desc="Training FAQ"):
+                batch_contents = contents[i:i + batch_size]
+
+                embedding_payloads = []
+                search_payloads = []
+                vector_ids = []
+
+                for content in batch_contents:
+                    if content['content_type'] == CognitionDataType.json.value:
+                        metadata = processor.find_matching_metadata(self.bot, content['data'],
+                                                                    content.get('collection'))
+                        search_payload, embedding_payload = Utility.retrieve_search_payload_and_embedding_payload(
+                            content['data'], metadata)
+                    else:
+                        search_payload, embedding_payload = {'content': content["data"]}, content["data"]
+
+                    embedding_payloads.append(embedding_payload)
+                    search_payloads.append(search_payload)
+                    vector_ids.append(content['vector_id'])
+
+                embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)
+
+                points = [{'id': vector_ids[idx], 'vector': embeddings[idx], 'payload': search_payloads[idx]}
+                          for idx in range(len(vector_ids))]
                 await self.__collection_upsert__(collection, {'points': points},
                                                  err_msg="Unable to train FAQ! Contact support")
-                count += 1
+                count += len(batch_contents)
 
         return {"faq": count}
 
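Note: the hunk above is the heart of the optimization. Rows are grouped by collection and embedded in chunks of batch_size = 100, so a collection with N rows now costs ceil(N / 100) embedding calls instead of N. A minimal, self-contained sketch of the same group-then-chunk pattern (names and data here are illustrative, not from the kairon codebase):

from typing import Dict, List

def group_by_collection(rows: List[dict]) -> Dict[str, List[dict]]:
    # One bucket per collection name, mirroring the grouping loop above.
    groups: Dict[str, List[dict]] = {}
    for row in rows:
        groups.setdefault(row.get('collection') or '', []).append(row)
    return groups

def chunk(items: List[dict], batch_size: int = 100) -> List[List[dict]]:
    # Fixed-size batches, mirroring range(0, len(contents), batch_size).
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

rows = [{'collection': 'faq', 'data': f'doc {i}'} for i in range(250)]
batches = [b for group in group_by_collection(rows).values() for b in chunk(group)]
assert [len(b) for b in batches] == [100, 100, 50]  # 3 embedding calls instead of 250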
@@ -104,21 +126,42 @@ async def predict(self, query: Text, user, *args, **kwargs) -> Tuple:
         elapsed_time = end_time - start_time
         return response, elapsed_time
 
-    def truncate_text(self, text: Text) -> Text:
+    def truncate_text(self, texts: List[Text]) -> List[Text]:
         """
-        Truncate text to 8191 tokens for openai
+        Truncate multiple texts to 8191 tokens for openai
         """
-        tokens = self.tokenizer.encode(text)[:self.EMBEDDING_CTX_LENGTH]
-        return self.tokenizer.decode(tokens)
-
-    async def get_embedding(self, text: Text, user, **kwargs) -> List[float]:
-        truncated_text = self.truncate_text(text)
-        result = await litellm.aembedding(model="text-embedding-3-small",
-                                          input=[truncated_text],
-                                          metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
-                                          api_key=self.llm_secret_embedding.get('api_key'),
-                                          num_retries=3)
-        return result["data"][0]["embedding"]
+        truncated_texts = []
+
+        for text in texts:
+            tokens = self.tokenizer.encode(text)[:self.EMBEDDING_CTX_LENGTH]
+            truncated_texts.append(self.tokenizer.decode(tokens))
+
+        return truncated_texts
+
+    async def get_embedding(self, texts: Union[Text, List[Text]], user, **kwargs):
+        """
+        Get embeddings for a batch of texts.
+        """
+        is_single_text = isinstance(texts, str)
+        if is_single_text:
+            texts = [texts]
+
+        truncated_texts = self.truncate_text(texts)
+
+        result = await litellm.aembedding(
+            model="text-embedding-3-small",
+            input=truncated_texts,
+            metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
+            api_key=self.llm_secret_embedding.get('api_key'),
+            num_retries=3
+        )
+
+        embeddings = [embedding["embedding"] for embedding in result["data"]]
+
+        if is_single_text:
+            return embeddings[0]
+
+        return embeddings
 
     async def __parse_completion_response(self, response, **kwargs):
         if kwargs.get("stream"):
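Note: get_embedding now accepts either a single string or a list, normalizes to a list, makes one litellm.aembedding call, and unwraps the result for single-text callers such as predict, so existing call sites keep working. A hedged sketch of that normalize-then-unwrap idiom, with a dummy backend standing in for litellm:

import asyncio
from typing import List, Union

async def fake_backend(batch: List[str]) -> List[List[float]]:
    # Stand-in for litellm.aembedding: one vector per input, in input order.
    return [[float(len(text))] for text in batch]

async def embed(texts: Union[str, List[str]]):
    # Normalize single-text input to a batch of one.
    is_single = isinstance(texts, str)
    batch = [texts] if is_single else texts
    vectors = await fake_backend(batch)
    # Unwrap so single-text callers still receive one vector, not a list of one.
    return vectors[0] if is_single else vectors

print(asyncio.run(embed('hello')))          # [5.0]
print(asyncio.run(embed(['hi', 'there'])))  # [[2.0], [5.0]]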
@@ -322,4 +365,4 @@ def fetch_llm_metadata(bot: str):

         metadata[llm_type]['properties']['model']['enum'] = models
 
-    return metadata
+    return metadata
4 changes: 3 additions & 1 deletion tests/integration_test/action_service_test.py
@@ -12802,7 +12802,7 @@ def __mock_fetch_similar(*args, **kwargs):
 @mock.patch.object(litellm, "aembedding", autospec=True)
 @mock.patch("kairon.shared.actions.utils.ActionUtility.compose_response", autospec=True)
 @mock.patch("kairon.shared.rest_client.AioRestClient.request", autospec=True)
-def test_prompt_action_set_slots(mock_search, mock_slot_set, mock_mock_embedding, mock_completion):
+def test_prompt_action_set_slots(mock_search, mock_slot_set, mock_embedding, mock_completion):
     action_name = "kairon_faq_action"
     bot = "5u80fd0a56c908ca10d35d2sjhjhjhj"
     user = "udit.pandey"
@@ -12828,6 +12828,8 @@ def mock_completion_for_answer(*args, **kwargs):
         return litellm.ModelResponse(**{'choices': [{'message': {'content': generated_text, 'role': 'assistant'}}]})
 
     mock_completion.return_value = mock_completion_for_answer()
+    embedding = list(np.random.random(OPENAI_EMBEDDING_OUTPUT))
+    mock_embedding.return_value = litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]})
     mock_completion.return_value = litellm.ModelResponse(
         **{'choices': [{'message': {'content': generated_text, 'role': 'assistant'}}]})
     log1 = ['Slot: api_type', 'evaluation_type: expression', f"data: {generated_text}", 'response: filter']
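Note: the rename from mock_mock_embedding to mock_embedding matters because unittest.mock.patch decorators are applied bottom-up — the decorator closest to the function supplies the first mock argument — and the new return_value must land on the mock that patches litellm.aembedding. A tiny sketch of that ordering rule (json is used purely for illustration):

import json
from unittest import mock

@mock.patch('json.dumps')   # outermost patch -> second mock argument
@mock.patch('json.loads')   # innermost patch -> first mock argument
def check(mock_loads, mock_dumps):
    assert json.loads is mock_loads
    assert json.dumps is mock_dumps

check()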
2 changes: 1 addition & 1 deletion tests/unit_test/data_processor/data_processor_test.py
@@ -3980,7 +3980,7 @@ def test_start_training_with_llm_faq(
     settings.llm_settings = LLMSettings(enable_faq=True)
     settings.save()
     embedding = list(np.random.random(1532))
-    mock_openai.return_value = {'data': [{'embedding': embedding}]}
+    mock_openai.return_value = {'data': [{'embedding': embedding}, {'embedding': embedding}]}
     mock_bot.return_value = {"account": 1}
     mock_train.return_value = f"/models/{bot}"
     start_training(bot, user)
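Note: the mock now returns two embeddings in a single response because the processor embeds all rows of a collection in one call; the contract the test relies on is one 'data' entry per input, in input order (the OpenAI-style response shape). A small sketch of that shape, with hypothetical values:

# Hypothetical response the updated mock mimics: one entry per input text.
fake_response = {'data': [{'embedding': [0.1] * 1532}, {'embedding': [0.2] * 1532}]}
vectors = [item['embedding'] for item in fake_response['data']]
assert len(vectors) == 2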
41 changes: 20 additions & 21 deletions tests/unit_test/llm_test.py
@@ -99,23 +99,18 @@ async def test_gpt3_faq_embedding_train_payload_text(self, mock_embedding, aioresponses):
         bot = "test_embed_faq_text"
         user = "test"
         value = "nupurkhare"
-        CognitionSchema(
-            metadata=[{"column_name": "name", "data_type": "str", "enable_search": True, "create_embeddings": True},
-                      {"column_name": "city", "data_type": "str", "enable_search": False, "create_embeddings": True}],
-            collection_name="User_details",
-            bot=bot, user=user
-        ).save()
         CognitionSchema(
             metadata=[{"column_name": "country", "data_type": "str", "enable_search": True, "create_embeddings": True},
                       {"column_name": "lang", "data_type": "str", "enable_search": False, "create_embeddings": True},
                       {"column_name": "role", "data_type": "str", "enable_search": True, "create_embeddings": True}],
             collection_name="Country_details",
             bot=bot, user=user).save()
-        test_content = CognitionData(
-            data={"name": "Nupur", "city": "Pune"},
-            content_type="json",
-            collection="User_details",
-            bot=bot, user=user).save()
+        CognitionSchema(
+            metadata=[{"column_name": "name", "data_type": "str", "enable_search": True, "create_embeddings": True},
+                      {"column_name": "city", "data_type": "str", "enable_search": False, "create_embeddings": True}],
+            collection_name="User_details",
+            bot=bot, user=user
+        ).save()
         test_content_two = CognitionData(
             data={"country": "Spain", "lang": "spanish"},
             content_type="json",
@@ -126,6 +121,11 @@ async def test_gpt3_faq_embedding_train_payload_text(self, mock_embedding, aioresponses):
             content_type="json",
             collection="Country_details",
             bot=bot, user=user).save()
+        test_content = CognitionData(
+            data={"name": "Nupur", "city": "Pune"},
+            content_type="json",
+            collection="User_details",
+            bot=bot, user=user).save()
 
         llm_secret = LLMSecret(
             llm_type="openai",
@@ -138,9 +138,9 @@ async def test_gpt3_faq_embedding_train_payload_text(self, mock_embedding, aioresponses):
         llm_secret.save()
 
         embedding = list(np.random.random(LLMProcessor.__embedding__))
-        mock_embedding.side_effect = (litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]}),
-                                      litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]}),
-                                      litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]}))
+        mock_embedding.side_effect = (
+            litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}, {'embedding': embedding}]}),
+            litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]}))
         gpt3 = LLMProcessor(bot, DEFAULT_LLM)
         with mock.patch.dict(Utility.environment, {'llm': {"faq": "GPT3_FAQ_EMBED", 'api_key': llm_secret}}):
             aioresponses.add(
@@ -196,16 +196,14 @@ async def test_gpt3_faq_embedding_train_payload_text(self, mock_embedding, aioresponses):
         assert list(aioresponses.requests.values())[2][0].kwargs['json'] == {
             'name': f"{gpt3.bot}_country_details{gpt3.suffix}",
             'vectors': gpt3.vector_config}
-
         assert list(aioresponses.requests.values())[3][0].kwargs['json'] == {
             'points': [{'id': test_content_two.vector_id,
                         'vector': embedding,
-                        'payload': {'country': 'Spain'}}]}
-        assert list(aioresponses.requests.values())[3][1].kwargs['json'] == {
-            'points': [{'id': test_content_three.vector_id,
+                        'payload': {'country': 'Spain'}},
+                       {'id': test_content_three.vector_id,
                         'vector': embedding,
-                        'payload': {'role': 'ds'}}]}
-
+                        'payload': {'role': 'ds'}}
+                       ]}
         assert list(aioresponses.requests.values())[4][0].kwargs['json'] == {
             'name': f"{gpt3.bot}_user_details{gpt3.suffix}",
             'vectors': gpt3.vector_config}
@@ -216,7 +214,8 @@ async def test_gpt3_faq_embedding_train_payload_text(self, mock_embedding, aioresponses):
         assert response['faq'] == 3
 
         expected = {"model": "text-embedding-3-small",
-                    "input": [json.dumps(test_content.data)], 'metadata': {'user': user, 'bot': bot, 'invocation': None},
+                    "input": [json.dumps(test_content.data)],
+                    'metadata': {'user': user, 'bot': bot, 'invocation': None},
                     "api_key": value,
                     "num_retries": 3}
         assert not DeepDiff(mock_embedding.call_args[1], expected, ignore_order=True)
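Note: with batching, mock_embedding.side_effect supplies one EmbeddingResponse per upsert batch — first the Country_details batch of two rows, then the User_details batch of one — because an iterable side_effect yields one element per call, in call order. A minimal sketch of that behaviour (values here are placeholders, not real embeddings):

import asyncio
from unittest import mock

m = mock.AsyncMock(side_effect=[['e1', 'e2'], ['e3']])

async def demo():
    # The first call consumes the first element, the second call the next.
    assert await m(['a', 'b']) == ['e1', 'e2']
    assert await m(['c']) == ['e3']

asyncio.run(demo())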
