Model training time optimization #1584

Closed
Changes from 1 commit
107 changes: 75 additions & 32 deletions kairon/shared/llm/processor.py
@@ -1,6 +1,6 @@
import time
from secrets import randbelow, choice
from typing import Text, Dict, List, Tuple
from typing import Text, Dict, List, Tuple, Union
from urllib.parse import urljoin

import litellm
@@ -54,27 +54,49 @@ async def train(self, user, *args, **kwargs) -> Dict:
await self.__delete_collections()
count = 0
processor = CognitionDataProcessor()
collection_groups = list(CognitionData.objects.aggregate([
{'$match': {'bot': self.bot}},
{'$group': {'_id': "$collection", 'content': {'$push': "$$ROOT"}}},
{'$project': {'collection': "$_id", 'content': 1, '_id': 0}}
]))
for collections in collection_groups:
collection = f"{self.bot}_{collections['collection']}{self.suffix}" if collections[
'collection'] else f"{self.bot}{self.suffix}"
batch_size = 100

Comment on lines +57 to +58

🛠️ Refactor suggestion

Consider making 'batch_size' a configurable parameter

Currently, batch_size is hard-coded to 100. To enhance flexibility and allow for optimization based on different datasets or environments, consider making batch_size a configurable parameter or class attribute.
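
One way to do this is sketched below; the attribute name, default constant, and constructor shape are illustrative assumptions, not code from this PR:

class LLMProcessor:
    DEFAULT_TRAIN_BATCH_SIZE = 100  # assumed default, mirroring the value hard-coded in train()

    def __init__(self, bot, llm_type, train_batch_size=None):
        self.bot = bot
        self.llm_type = llm_type
        # Allow callers or environment-specific config to override the training batch size.
        self.train_batch_size = train_batch_size or self.DEFAULT_TRAIN_BATCH_SIZE

train() would then read batch_size = self.train_batch_size instead of the literal 100.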

collections_data = CognitionData.objects(bot=self.bot)
collection_groups = {}
for content in collections_data:
content_dict = content.to_mongo()
collection_name = content_dict.get('collection') or ""
if collection_name not in collection_groups:
collection_groups[collection_name] = []
collection_groups[collection_name].append(content_dict)

for collection_name, contents in collection_groups.items():
collection = f"{self.bot}_{collection_name}{self.suffix}" if collection_name else f"{self.bot}{self.suffix}"
await self.__create_collection__(collection)
for content in tqdm(collections['content'], desc="Training FAQ"):
if content['content_type'] == CognitionDataType.json.value:
metadata = processor.find_matching_metadata(self.bot, content['data'], content.get('collection'))
search_payload, embedding_payload = Utility.retrieve_search_payload_and_embedding_payload(
content['data'], metadata)
else:
search_payload, embedding_payload = {'content': content["data"]}, content["data"]
embeddings = await self.get_embedding(embedding_payload, user, invocation=invocation)
points = [{'id': content['vector_id'], 'vector': embeddings, 'payload': search_payload}]

for i in tqdm(range(0, len(contents), batch_size), desc="Training FAQ"):
batch_contents = contents[i:i + batch_size]

embedding_payloads = []
search_payloads = []
vector_ids = []

for content in batch_contents:
if content['content_type'] == CognitionDataType.json.value:
metadata = processor.find_matching_metadata(self.bot, content['data'],
content.get('collection'))
search_payload, embedding_payload = Utility.retrieve_search_payload_and_embedding_payload(
content['data'], metadata)
else:
search_payload, embedding_payload = {'content': content["data"]}, content["data"]

embedding_payloads.append(embedding_payload)
search_payloads.append(search_payload)
vector_ids.append(content['vector_id'])

Comment on lines +80 to +91

⚠️ Potential issue

Ensure 'vector_id' is always present in 'content'

In line 90, vector_ids.append(content['vector_id']) will raise a KeyError if 'vector_id' is not present in content. Ensure that every content dictionary contains the 'vector_id' key, or handle the missing case to prevent runtime errors.

Apply this diff to add a safety check:

 for content in batch_contents:
     if content['content_type'] == CognitionDataType.json.value:
         # existing code
     else:
         # existing code

+    if 'vector_id' in content:
         embedding_payloads.append(embedding_payload)
         search_payloads.append(search_payload)
         vector_ids.append(content['vector_id'])
+    else:
+        logging.warning(f"Missing 'vector_id' in content: {content}")

Committable suggestion was skipped due to low confidence.

embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)

Comment on lines +92 to +93

⚠️ Potential issue

Check API limits for embedding batch sizes

When calling self.get_embedding(embedding_payloads, user, invocation=invocation), ensure that the batch size does not exceed the limits imposed by the embedding API. Large batch sizes might lead to API errors or throttling. Verify the maximum allowed batch size and adjust if necessary.
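
A minimal safeguard, sketched under the assumption that the provider enforces a per-request input cap; the 2048 limit and helper name below are illustrative, not taken from this PR or the provider's documentation:

EMBEDDING_API_BATCH_LIMIT = 2048  # assumed cap; check the embedding provider's documented per-request limit

async def get_embedding_chunked(self, texts: List[Text], user, **kwargs) -> List[List[float]]:
    """Call get_embedding in sub-batches so one training batch never exceeds the provider limit."""
    embeddings = []
    for start in range(0, len(texts), EMBEDDING_API_BATCH_LIMIT):
        chunk = texts[start:start + EMBEDDING_API_BATCH_LIMIT]
        embeddings.extend(await self.get_embedding(chunk, user, **kwargs))
    return embeddings

With batch_size at 100 a single train() batch stays well under such a cap, but a helper like this keeps larger batch sizes safe.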

points = [{'id': vector_ids[idx], 'vector': embeddings[idx], 'payload': search_payloads[idx]}
for idx in range(len(vector_ids))]
await self.__collection_upsert__(collection, {'points': points},
err_msg="Unable to train FAQ! Contact support")
count += 1
count += len(batch_contents)

return {"faq": count}

async def predict(self, query: Text, user, *args, **kwargs) -> Tuple:
@@ -104,21 +126,42 @@ async def predict(self, query: Text, user, *args, **kwargs) -> Tuple:
elapsed_time = end_time - start_time
return response, elapsed_time

def truncate_text(self, text: Text) -> Text:
def truncate_text(self, texts: List[Text]) -> List[Text]:
"""
Truncate text to 8191 tokens for openai
Truncate multiple texts to 8191 tokens for openai
"""
tokens = self.tokenizer.encode(text)[:self.EMBEDDING_CTX_LENGTH]
return self.tokenizer.decode(tokens)

async def get_embedding(self, text: Text, user, **kwargs) -> List[float]:
truncated_text = self.truncate_text(text)
result = await litellm.aembedding(model="text-embedding-3-small",
input=[truncated_text],
metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
api_key=self.llm_secret_embedding.get('api_key'),
num_retries=3)
return result["data"][0]["embedding"]
truncated_texts = []

for text in texts:
tokens = self.tokenizer.encode(text)[:self.EMBEDDING_CTX_LENGTH]
truncated_texts.append(self.tokenizer.decode(tokens))

return truncated_texts

async def get_embedding(self, texts: Union[Text, List[Text]], user, **kwargs):
"""
Get embeddings for a batch of texts.
"""
is_single_text = isinstance(texts, str)
if is_single_text:
texts = [texts]

truncated_texts = self.truncate_text(texts)

Comment on lines +149 to +150

🛠️ Refactor suggestion

Optimize 'truncate_text' using list comprehension

You can simplify the truncate_text method by using a list comprehension for improved readability and performance.

Apply this diff to refactor the method:

 def truncate_text(self, texts: List[Text]) -> List[Text]:
     """
     Truncate multiple texts to 8191 tokens for openai
     """
-    truncated_texts = []
-
-    for text in texts:
-        tokens = self.tokenizer.encode(text)[:self.EMBEDDING_CTX_LENGTH]
-        truncated_texts.append(self.tokenizer.decode(tokens))
-
-    return truncated_texts
+    return [
+        self.tokenizer.decode(self.tokenizer.encode(text)[:self.EMBEDDING_CTX_LENGTH])
+        for text in texts
+    ]

result = await litellm.aembedding(
model="text-embedding-3-small",
input=truncated_texts,
metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
api_key=self.llm_secret_embedding.get('api_key'),
num_retries=3
)

embeddings = [embedding["embedding"] for embedding in result["data"]]

if is_single_text:
return embeddings[0]

return embeddings
Comment on lines +145 to +164

⚠️ Potential issue

Add error handling for empty or invalid embeddings

In the get_embedding method, consider adding error handling to manage cases where the embedding API returns empty or invalid results. This will enhance the robustness of your code.

Apply this diff to include error checking:

 embeddings = [embedding["embedding"] for embedding in result["data"]]

+if not embeddings or len(embeddings) != len(truncated_texts):
+    raise AppException("Failed to retrieve embeddings for all texts.")


async def __parse_completion_response(self, response, **kwargs):
if kwargs.get("stream"):