
Model training time optimization #1589

Conversation

@himanshugt16 (Contributor) commented on Oct 28, 2024

Summary by CodeRabbit

  • New Features

    • Enhanced batch processing for handling collections and embeddings, improving efficiency.
    • Updated methods to support batch inputs for text truncation and embedding retrieval.
  • Bug Fixes

    • Improved error handling for connection issues in embedding prediction tests.
  • Tests

    • Refined test cases to align with updated data structures and error handling, ensuring robust interactions with external services.

coderabbitai bot commented Oct 28, 2024

Walkthrough

The changes in this pull request focus on enhancing the LLMProcessor class in kairon/shared/llm/processor.py. Key modifications include refactoring the train method for batch processing of training data, updating the truncate_text and get_embedding methods to handle lists of texts, and improving error handling in the predict method. Additionally, the test suite for the LLMProcessor class has been updated to reflect changes in data structures and error management, ensuring robust testing for the new functionalities.
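To make the batching described above concrete, here is a minimal sketch of the pattern (not the PR's exact code; the processor argument stands in for an LLMProcessor instance, and the payload list mirrors the snippets quoted later in this review):

from typing import List


async def embed_in_batches(processor, payloads: List[str], user: str, batch_size: int = 100) -> List[List[float]]:
    """Sketch: issue one get_embedding call per chunk of payloads instead of one call per item."""
    all_embeddings: List[List[float]] = []
    for start in range(0, len(payloads), batch_size):
        batch = payloads[start:start + batch_size]
        # get_embedding now accepts a list of texts and returns one vector per input
        all_embeddings.extend(await processor.get_embedding(batch, user, invocation="train"))
    return all_embeddings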

Changes

  • kairon/shared/llm/processor.py: Refactored train for batch processing; updated truncate_text to accept lists; modified get_embedding for batch handling; added an import for Union.
  • tests/unit_test/llm_test.py: Updated multiple test methods to reflect changes in CognitionSchema and CognitionData; enhanced error handling for connection issues; modified assertions for payload structures.

Poem

🐇 In the meadow where data flows,
The LLMProcessor now brightly glows.
With batches of texts, it swiftly trains,
While handling errors, it never wanes.
A hop, a skip, through code we dance,
Celebrating changes, let’s take a chance! 🌼



@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (2)
tests/unit_test/llm_test.py (1)

141-143: Consider using a more descriptive variable name for mock responses

The mock_embedding.side_effect setup could be more readable with a descriptive variable name for the embedding responses.

-mock_embedding.side_effect = (
-        litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}, {'embedding': embedding}]}),
-        litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]}))
+mock_embedding.side_effect = (
+        litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}, {'embedding': embedding}]}),  # Multiple embeddings response
+        litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]}))  # Single embedding response
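An alternative sketch that follows the naming suggestion literally, binding each canned response to a descriptive name before wiring the mock (it reuses the test's existing embedding value and mock_embedding mock):

two_item_batch_response = litellm.EmbeddingResponse(
    **{'data': [{'embedding': embedding}, {'embedding': embedding}]})
single_item_response = litellm.EmbeddingResponse(**{'data': [{'embedding': embedding}]})
mock_embedding.side_effect = (two_item_batch_response, single_item_response)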
tests/unit_test/data_processor/data_processor_test.py (1)

3983-3983: Consider making the embedding dimension configurable.

The embedding size (1532) is hardcoded, which could make the test brittle if the model's embedding dimensions change. Consider extracting this as a configuration parameter or constant.

-        embedding = list(np.random.random(1532))
+        EMBEDDING_DIM = 1532  # Define at module level or fetch from config
+        embedding = list(np.random.random(EMBEDDING_DIM))
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR, between commits 9a7e824 and 5c8eb14.

📒 Files selected for processing (4)
  • kairon/shared/llm/processor.py (4 hunks)
  • tests/integration_test/action_service_test.py (2 hunks)
  • tests/unit_test/data_processor/data_processor_test.py (1 hunks)
  • tests/unit_test/llm_test.py (5 hunks)
🔥 Files not summarized due to errors (2)
  • tests/integration_test/action_service_test.py: Error: Server error: no LLM provider could handle the message
  • tests/unit_test/data_processor/data_processor_test.py: Error: Server error: no LLM provider could handle the message
🔇 Additional comments (6)
tests/unit_test/llm_test.py (4)

108-113: LGTM: Schema definition for User_details collection

The schema definition for the User_details collection is well-structured with appropriate metadata fields and search configurations.


124-128: LGTM: Test data creation for User_details collection

The test data creation aligns with the schema definition and includes all required fields.


202-206: LGTM: Updated payload assertions

The payload assertions correctly verify the structure of the points data, including the vector IDs and payload content.


217-218: 🛠️ Refactor suggestion

Consider adding validation for JSON serialization

The test assumes JSON serialization will always succeed. Consider adding a test case for invalid JSON data, for example:

def test_gpt3_faq_embedding_train_payload_invalid_json(self, mock_embedding, aioresponses):
    test_content = CognitionData(
        data={"name": float('inf')},  # Invalid JSON data
        content_type="json",
        collection="User_details",
        bot="test_bot",
        user="test"
    ).save()
    # Test implementation
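For context on why float('inf') qualifies as invalid JSON here (standard-library behaviour, not code from the PR):

import json

print(json.dumps({"name": float("inf")}))            # {"name": Infinity}  (not valid JSON)
json.dumps({"name": float("inf")}, allow_nan=False)   # raises ValueError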
tests/integration_test/action_service_test.py (2)

12805-12805: LGTM! Well-structured test setup.

The test function is properly decorated with all necessary mocks for litellm embeddings, action utility, and REST client.


12831-12832: Verify the embedding dimension constant.

The test uses OPENAI_EMBEDDING_OUTPUT to generate random embeddings. Let's ensure this constant matches the expected embedding dimension for the model in use.

✅ Verification successful

The embedding dimension constant is correctly set.

The test uses OPENAI_EMBEDDING_OUTPUT = 1536 which matches the expected dimension for the text-embedding-3-small model used throughout the codebase. This is consistent with:

  • The __embedding__ constant in LLMProcessor and Qdrant classes
  • The default embedding_size in OpenAI classifier configurations
  • The model configuration using "text-embedding-3-small" across all embedding generations
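If desired, a small guard test could pin this in place. A sketch only; it assumes __embedding__ is a plain class attribute on LLMProcessor, as the findings above indicate:

from kairon.shared.llm.processor import LLMProcessor


def test_embedding_dimension_constant():
    # Keeps the test-side OPENAI_EMBEDDING_OUTPUT (1536) in sync with the processor's constant.
    assert LLMProcessor.__embedding__ == 1536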
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the embedding dimension constant matches the model specification

# Search for the constant definition
rg -A 1 "OPENAI_EMBEDDING_OUTPUT\s*=" 

# Search for any model configuration that might specify the embedding dimension
rg -A 5 "embedding.*model.*config" 

Length of output: 200


Script:

#!/bin/bash
# Let's search for any OpenAI embedding model configurations or usages
rg -A 3 -B 3 "text-embedding" 

# Also search for any embedding dimension configurations
rg -A 2 "(embedding.*dimension|dimension.*embedding)"

# Search for any litellm embedding configurations
rg -A 3 "litellm.*embedding"

Length of output: 50067


Script:

#!/bin/bash
# Let's search for the embedding dimension constant and any related configurations
rg -A 2 "__embedding__\s*=" 

# Also search for any OpenAI embedding model configurations
rg -A 2 "embedding_size"

Length of output: 1501

@@ -3980,7 +3980,7 @@ def test_start_training_with_llm_faq(
     settings.llm_settings = LLMSettings(enable_faq=True)
     settings.save()
     embedding = list(np.random.random(1532))
-    mock_openai.return_value = {'data': [{'embedding': embedding}]}
+    mock_openai.return_value = {'data': [{'embedding': embedding}, {'embedding': embedding}]}

💡 Codebase verification

The test coverage for batch processing is insufficient

The codebase shows batch processing implementation with specific batch sizes:

  • OpenAI featurizer uses batch_size=64
  • LLM processor uses batch_size=100

However, the current test only mocks a single batch response with 2 embeddings. Additional test cases should be added to verify (a test sketch follows the locations list below):

  • Processing of full batch (64/100 items)
  • Handling of partial batches
  • Edge cases with single item
  • Error handling for batch failures

Locations needing test coverage:

  • kairon/shared/nlu/featurizer/openai.py
  • kairon/shared/llm/processor.py
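A sketch of what a parametrized batching test could look like (a standalone example, not code from the repository):

import pytest


@pytest.mark.parametrize("num_items", [1, 64, 99, 100, 101, 250])
def test_batching_respects_batch_size(num_items):
    payloads = [f"text-{i}" for i in range(num_items)]
    batch_size = 100
    batches = [payloads[i:i + batch_size] for i in range(0, len(payloads), batch_size)]
    assert all(len(batch) <= batch_size for batch in batches)   # no oversized batch
    assert sum(len(batch) for batch in batches) == num_items    # nothing dropped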
🔗 Analysis chain

Enhance test coverage for batch processing scenarios.

Since the mock response contains multiple embeddings, consider adding test cases that verify:

  1. Different batch sizes
  2. Error handling for partial batch success
  3. Edge cases with empty or single-item batches

Let's check if these scenarios are covered elsewhere:


Let me search for more context around the test cases to better understand the coverage.


Let me try a different approach to search for the test context.


Let me search for the actual implementation to understand the batch processing behavior.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for related test cases in the test file
rg -A 5 "mock_openai.*embedding" "tests/unit_test/data_processor/data_processor_test.py"

Length of output: 434


Script:

#!/bin/bash
# Search for the test function containing this mock and surrounding context
rg -B 10 -A 10 "mock_openai.*embedding.*return_value.*embedding" "tests/unit_test/data_processor/data_processor_test.py"

# Also search for other test functions that might be testing batch processing
rg -B 5 -A 5 "def test.*batch" "tests/unit_test/data_processor/data_processor_test.py"

Length of output: 207


Script:

#!/bin/bash
# Search for the test function with more context and less specific pattern
rg -B 20 -A 20 "mock_openai.*return_value.*embedding" "tests/unit_test/data_processor/data_processor_test.py"

# Also look for any embedding-related test functions
rg -B 5 -A 5 "def test.*embed" "tests/unit_test/data_processor/data_processor_test.py"

Length of output: 11075


Script:

#!/bin/bash
# Search for the embedding generation implementation
ast-grep --pattern 'def $FUNC($$$) {
  $$$
  openai$$$embedding$$$
  $$$
}'

# Also look for any batch size configuration or constants
rg -B 2 -A 2 "batch.*size" 

Length of output: 3252

for collections in collection_groups:
    collection = f"{self.bot}_{collections['collection']}{self.suffix}" if collections[
        'collection'] else f"{self.bot}{self.suffix}"
    batch_size = 100

🛠️ Refactor suggestion

Make batch_size a configurable parameter

Currently, batch_size is hardcoded to 100. Consider making it a configurable parameter to allow flexibility for different data volumes and system capabilities.
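One possible shape for this change, as a sketch only (the environment variable name TRAIN_EMBEDDING_BATCH_SIZE is hypothetical, not an existing setting):

import os


def resolve_train_batch_size(default: int = 100) -> int:
    """Hypothetical helper: let deployments override the embedding batch size instead of hardcoding 100."""
    return int(os.environ.get("TRAIN_EMBEDDING_BATCH_SIZE", str(default)))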

@@ -322,4 +365,4 @@

     metadata[llm_type]['properties']['model']['enum'] = models

-    return metadata
+    return metadata

⚠️ Potential issue

Avoid modifying shared metadata directly

In fetch_llm_metadata, modifying Utility.llm_metadata directly can lead to unintended side effects elsewhere in the application. Consider creating a deep copy of the metadata before making changes.

Apply this diff to create a copy:

+import copy

 def fetch_llm_metadata(bot: str):
-    metadata = Utility.llm_metadata
+    metadata = copy.deepcopy(Utility.llm_metadata)
     ...

Committable suggestion was skipped due to low confidence.
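To illustrate the aliasing problem the suggestion guards against (a standalone example, not repository code):

import copy

shared = {"openai": {"properties": {"model": {"enum": ["gpt-4o"]}}}}

aliased = shared                       # no copy: both names point at the same nested dicts
aliased["openai"]["properties"]["model"]["enum"] = ["per-bot-model"]
print(shared["openai"]["properties"]["model"]["enum"])   # ['per-bot-model']: the shared metadata was mutated

safe = copy.deepcopy(shared)           # deep copy: nested dicts are duplicated
safe["openai"]["properties"]["model"]["enum"] = ["another-model"]
print(shared["openai"]["properties"]["model"]["enum"])   # still ['per-bot-model']: unchanged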

truncated_texts = self.truncate_text(texts)

result = await litellm.aembedding(
    model="text-embedding-3-small",

🛠️ Refactor suggestion

Make the embedding model configurable

The model name "text-embedding-3-small" is hardcoded. To enhance flexibility and allow for easy updates or experimentation with different models, consider making the model name a configurable parameter.

Apply this diff to use a configurable model name:

+model_name = self.llm_secret_embedding.get('embedding_model', "text-embedding-3-small")
 result = await litellm.aembedding(
-    model="text-embedding-3-small",
+    model=model_name,
     input=truncated_texts,
     ...
 )

Committable suggestion was skipped due to low confidence.

Comment on lines +91 to +94

embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)

points = [{'id': vector_ids[idx], 'vector': embeddings[idx], 'payload': search_payloads[idx]}

⚠️ Potential issue

Add exception handling for embedding retrieval

When calling self.get_embedding, exceptions from the embedding service may occur (e.g., network issues, API errors). Adding a try-except block can help handle these exceptions gracefully and enhance the robustness of the training process.

Apply this diff to add exception handling:

-embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)
+try:
+    embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)
+except Exception as e:
+    logging.exception(f"Failed to get embeddings: {e}")
+    continue  # Optionally handle the failure (e.g., skip this batch or retry)

Committable suggestion was skipped due to low confidence.

Comment on lines +145 to +164
is_single_text = isinstance(texts, str)
if is_single_text:
    texts = [texts]

truncated_texts = self.truncate_text(texts)

result = await litellm.aembedding(
    model="text-embedding-3-small",
    input=truncated_texts,
    metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
    api_key=self.llm_secret_embedding.get('api_key'),
    num_retries=3
)

embeddings = [embedding["embedding"] for embedding in result["data"]]

if is_single_text:
    return embeddings[0]

return embeddings

⚠️ Potential issue

Handle exceptions in get_embedding method

The get_embedding method does not currently handle potential exceptions from litellm.aembedding. Implementing exception handling here will prevent unhandled exceptions from propagating and improve error reporting.

Apply this diff to handle exceptions:

 async def get_embedding(self, texts: Union[Text, List[Text]], user, **kwargs):
     ...
     try:
         result = await litellm.aembedding(
             model="text-embedding-3-small",
             input=truncated_texts,
             metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
             api_key=self.llm_secret_embedding.get('api_key'),
             num_retries=3
         )
     except Exception as e:
         logging.exception(f"Embedding retrieval failed: {e}")
         raise AppException("Failed to retrieve embeddings") from e
     ...

Committable suggestion was skipped due to low confidence.
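Worth noting: num_retries=3 already asks litellm to retry transient failures, so the try/except above mainly turns the final failure into a clearer AppException for callers rather than adding new retry behaviour.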

@hiteshghuge (Collaborator) left a comment


approved
