Model training time optimization #1589
Conversation
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (2)
tests/unit_test/llm_test.py (1)
141-143: Consider using a more descriptive variable name for mock responses

The `mock_embedding.side_effect` setup could be more readable with a descriptive variable name for the embedding responses.

```diff
-mock_embedding.side_effect = (
-    litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}, {'embedding': embedding}]}),
-    litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]}))
+mock_embedding.side_effect = (
+    litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}, {'embedding': embedding}]}),  # Multiple embeddings response
+    litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]}))  # Single embedding response
```

tests/unit_test/data_processor/data_processor_test.py (1)
3983-3983: Consider making the embedding dimension configurable.

The embedding size (1532) is hardcoded, which could make the test brittle if the model's embedding dimensions change. Consider extracting this as a configuration parameter or constant.

```diff
-embedding = list(np.random.random(1532))
+EMBEDDING_DIM = 1532  # Define at module level or fetch from config
+embedding = list(np.random.random(EMBEDDING_DIM))
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)
- kairon/shared/llm/processor.py (4 hunks)
- tests/integration_test/action_service_test.py (2 hunks)
- tests/unit_test/data_processor/data_processor_test.py (1 hunks)
- tests/unit_test/llm_test.py (5 hunks)
🔥 Files not summarized due to errors (2)
- tests/integration_test/action_service_test.py: Error: Server error: no LLM provider could handle the message
- tests/unit_test/data_processor/data_processor_test.py: Error: Server error: no LLM provider could handle the message
🔇 Additional comments (6)
tests/unit_test/llm_test.py (4)
108-113: LGTM: Schema definition for User_details collection

The schema definition for the User_details collection is well-structured with appropriate metadata fields and search configurations.
124-128: LGTM: Test data creation for User_details collection

The test data creation aligns with the schema definition and includes all required fields.
202-206: LGTM: Updated payload assertions

The payload assertions correctly verify the structure of the points data, including the vector IDs and payload content.
217-218: 🛠️ Refactor suggestion

Consider adding validation for JSON serialization

The test assumes the JSON serialization will always succeed. Consider adding a test case for invalid JSON data:

```python
def test_gpt3_faq_embedding_train_payload_invalid_json(self, mock_embedding, aioresponses):
    test_content = CognitionData(
        data={"name": float('inf')},  # Invalid JSON data
        content_type="json",
        collection="User_details",
        bot="test_bot",
        user="test"
    ).save()
    # Test implementation
```

tests/integration_test/action_service_test.py (2)
12805-12805: LGTM! Well-structured test setup.

The test function is properly decorated with all necessary mocks for litellm embeddings, action utility, and REST client.
12831-12832: Verify the embedding dimension constant.

The test uses `OPENAI_EMBEDDING_OUTPUT` to generate random embeddings. Let's ensure this constant matches the expected embedding dimension for the model in use.

✅ Verification successful
The embedding dimension constant is correctly set.

The test uses `OPENAI_EMBEDDING_OUTPUT = 1536`, which matches the expected dimension for the text-embedding-3-small model used throughout the codebase. This is consistent with:

- The `__embedding__` constant in LLMProcessor and Qdrant classes
- The default `embedding_size` in OpenAI classifier configurations
- The model configuration using "text-embedding-3-small" across all embedding generations
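As an aside (a hypothetical snippet, not code from this PR; the constant value is taken from the verification above), test fixtures can be derived from the shared constant rather than a literal, so a future model change only needs one edit:

```python
# Hypothetical sketch: derive the mock embedding from the shared constant
# instead of hardcoding the dimension in each test.
import numpy as np

OPENAI_EMBEDDING_OUTPUT = 1536  # assumed to mirror the codebase constant

embedding = list(np.random.random(OPENAI_EMBEDDING_OUTPUT))
mock_response = {'data': [{'embedding': embedding}]}
assert len(mock_response['data'][0]['embedding']) == OPENAI_EMBEDDING_OUTPUT
```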
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Verify the embedding dimension constant matches the model specification

# Search for the constant definition
rg -A 1 "OPENAI_EMBEDDING_OUTPUT\s*="

# Search for any model configuration that might specify the embedding dimension
rg -A 5 "embedding.*model.*config"
```

Length of output: 200
Script:
```bash
#!/bin/bash
# Let's search for any OpenAI embedding model configurations or usages
rg -A 3 -B 3 "text-embedding"

# Also search for any embedding dimension configurations
rg -A 2 "(embedding.*dimension|dimension.*embedding)"

# Search for any litellm embedding configurations
rg -A 3 "litellm.*embedding"
```

Length of output: 50067
Script:
```bash
#!/bin/bash
# Let's search for the embedding dimension constant and any related configurations
rg -A 2 "__embedding__\s*="

# Also search for any OpenAI embedding model configurations
rg -A 2 "embedding_size"
```

Length of output: 1501
```diff
@@ -3980,7 +3980,7 @@ def test_start_training_with_llm_faq(
     settings.llm_settings = LLMSettings(enable_faq=True)
     settings.save()
     embedding = list(np.random.random(1532))
-    mock_openai.return_value = {'data': [{'embedding': embedding}]}
+    mock_openai.return_value = {'data': [{'embedding': embedding}, {'embedding': embedding}]}
```
💡 Codebase verification
The test coverage for batch processing is insufficient
The codebase shows batch processing implementation with specific batch sizes:
- OpenAI featurizer uses batch_size=64
- LLM processor uses batch_size=100
However, the current test only mocks a single batch response with 2 embeddings. Additional test cases should be added to verify:
- Processing of full batch (64/100 items)
- Handling of partial batches
- Edge cases with single item
- Error handling for batch failures
Locations needing test coverage:

- kairon/shared/nlu/featurizer/openai.py
- kairon/shared/llm/processor.py
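To cover the scenarios listed above, one possible starting point is sketched below (hedged: the `chunk` helper and test names are hypothetical stand-ins for the batching done in `processor.py`, not code from this PR):

```python
# Hypothetical batching helper mirroring the batch_size=100 loop in the LLM
# processor; the parametrized cases cover single-item, partial, exactly-full,
# and multi-batch inputs.
from typing import List

import pytest

BATCH_SIZE = 100  # value used in kairon/shared/llm/processor.py per the review


def chunk(items: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


@pytest.mark.parametrize("num_items, expected_batches", [
    (1, 1),      # single item
    (50, 1),     # partial batch
    (100, 1),    # exactly one full batch
    (101, 2),    # full batch plus remainder
])
def test_batch_splitting(num_items, expected_batches):
    batches = chunk([f"text {i}" for i in range(num_items)])
    assert len(batches) == expected_batches
    assert sum(len(b) for b in batches) == num_items
```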
🔗 Analysis chain
Enhance test coverage for batch processing scenarios.
Since the mock response contains multiple embeddings, consider adding test cases that verify:
- Different batch sizes
- Error handling for partial batch success
- Edge cases with empty or single-item batches
Let's check if these scenarios are covered elsewhere:
Let me search for more context around the test cases to better understand the coverage.
Let me try a different approach to search for the test context.
Let me search for the actual implementation to understand the batch processing behavior.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Search for related test cases in the test file
rg -A 5 "mock_openai.*embedding" "tests/unit_test/data_processor/data_processor_test.py"
```
Length of output: 434
Script:
```bash
#!/bin/bash
# Search for the test function containing this mock and surrounding context
rg -B 10 -A 10 "mock_openai.*embedding.*return_value.*embedding" "tests/unit_test/data_processor/data_processor_test.py"

# Also search for other test functions that might be testing batch processing
rg -B 5 -A 5 "def test.*batch" "tests/unit_test/data_processor/data_processor_test.py"
```
Length of output: 207
Script:
```bash
#!/bin/bash
# Search for the test function with more context and less specific pattern
rg -B 20 -A 20 "mock_openai.*return_value.*embedding" "tests/unit_test/data_processor/data_processor_test.py"

# Also look for any embedding-related test functions
rg -B 5 -A 5 "def test.*embed" "tests/unit_test/data_processor/data_processor_test.py"
```
Length of output: 11075
Script:
```bash
#!/bin/bash
# Search for the embedding generation implementation
ast-grep --pattern 'def $FUNC($$$) {
  $$$
  openai$$$embedding$$$
  $$$
}'

# Also look for any batch size configuration or constants
rg -B 2 -A 2 "batch.*size"
```
Length of output: 3252
```python
for collections in collection_groups:
    collection = f"{self.bot}_{collections['collection']}{self.suffix}" if collections[
        'collection'] else f"{self.bot}{self.suffix}"
    batch_size = 100
```
🛠️ Refactor suggestion

Make `batch_size` a configurable parameter

Currently, `batch_size` is hardcoded to 100. Consider making it a configurable parameter to allow flexibility for different data volumes and system capabilities.
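One minimal way to do this is sketched below (the environment variable name and helper function are assumptions, not taken from the PR); the value could equally be passed through the processor's constructor or the bot settings:

```python
# Hypothetical configuration hook: read the batch size from the environment,
# keeping the current value of 100 as the default.
import os

DEFAULT_TRAIN_BATCH_SIZE = 100


def get_train_batch_size() -> int:
    return int(os.environ.get("LLM_TRAIN_BATCH_SIZE", DEFAULT_TRAIN_BATCH_SIZE))


# In the training loop, the hardcoded literal becomes:
batch_size = get_train_batch_size()
```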
```diff
@@ -322,4 +365,4 @@
        metadata[llm_type]['properties']['model']['enum'] = models

-    return metadata
+    return metadata
```
Avoid modifying shared metadata directly

In `fetch_llm_metadata`, modifying `Utility.llm_metadata` directly can lead to unintended side effects elsewhere in the application. Consider creating a deep copy of the metadata before making changes.

Apply this diff to create a copy:

```diff
+import copy

 def fetch_llm_metadata(bot: str):
-    metadata = Utility.llm_metadata
+    metadata = copy.deepcopy(Utility.llm_metadata)
     ...
```
Committable suggestion was skipped due to low confidence.
```python
truncated_texts = self.truncate_text(texts)

result = await litellm.aembedding(
    model="text-embedding-3-small",
```
🛠️ Refactor suggestion

Make the embedding model configurable

The model name `"text-embedding-3-small"` is hardcoded. To enhance flexibility and allow for easy updates or experimentation with different models, consider making the model name a configurable parameter.

Apply this diff to use a configurable model name:

```diff
+model_name = self.llm_secret_embedding.get('embedding_model', "text-embedding-3-small")
 result = await litellm.aembedding(
-    model="text-embedding-3-small",
+    model=model_name,
     input=truncated_texts,
     ...
 )
```
Committable suggestion was skipped due to low confidence.
```python
embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)

points = [{'id': vector_ids[idx], 'vector': embeddings[idx], 'payload': search_payloads[idx]}
```
Add exception handling for embedding retrieval

When calling `self.get_embedding`, exceptions from the embedding service may occur (e.g., network issues, API errors). Adding a try-except block can help handle these exceptions gracefully and enhance the robustness of the training process.

Apply this diff to add exception handling:

```diff
+try:
     embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)
+except Exception as e:
+    logging.exception(f"Failed to get embeddings: {e}")
+    continue  # Optionally handle the failure (e.g., skip this batch or retry)
```
Committable suggestion was skipped due to low confidence.
```python
is_single_text = isinstance(texts, str)
if is_single_text:
    texts = [texts]

truncated_texts = self.truncate_text(texts)

result = await litellm.aembedding(
    model="text-embedding-3-small",
    input=truncated_texts,
    metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
    api_key=self.llm_secret_embedding.get('api_key'),
    num_retries=3
)

embeddings = [embedding["embedding"] for embedding in result["data"]]

if is_single_text:
    return embeddings[0]

return embeddings
```
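For orientation (a hedged, self-contained sketch — `embed` below is a hypothetical stand-in, not the PR's implementation), the contract implemented above is: a single string returns one vector, a list returns a list of vectors.

```python
# Self-contained illustration of the single-text vs. list-of-texts contract;
# embed() stubs the response in the shape returned by litellm.aembedding.
import asyncio
from typing import List, Union


async def embed(texts: Union[str, List[str]]):
    is_single_text = isinstance(texts, str)
    if is_single_text:
        texts = [texts]
    result = {"data": [{"embedding": [0.0, 1.0, 2.0]} for _ in texts]}  # stubbed response
    embeddings = [item["embedding"] for item in result["data"]]
    return embeddings[0] if is_single_text else embeddings


single = asyncio.run(embed("hello"))
many = asyncio.run(embed(["hello", "world"]))
assert single == [0.0, 1.0, 2.0]  # a str input yields one flat vector
assert len(many) == 2             # a list input yields one vector per text
```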
Handle exceptions in the `get_embedding` method

The `get_embedding` method does not currently handle potential exceptions from `litellm.aembedding`. Implementing exception handling here will prevent unhandled exceptions from propagating and improve error reporting.

Apply this diff to handle exceptions:

```python
async def get_embedding(self, texts: Union[Text, List[Text]], user, **kwargs):
    ...
    try:
        result = await litellm.aembedding(
            model="text-embedding-3-small",
            input=truncated_texts,
            metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
            api_key=self.llm_secret_embedding.get('api_key'),
            num_retries=3
        )
    except Exception as e:
        logging.exception(f"Embedding retrieval failed: {e}")
        raise AppException("Failed to retrieve embeddings") from e
    ...
```
Committable suggestion was skipped due to low confidence.
approved
Summary by CodeRabbit
New Features
Bug Fixes
Tests