From aa7c5c06fa59b166aa14fbb52bca74a7bdf082c1 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 14:56:21 +0100 Subject: [PATCH] Lint --- .../amazon_bedrock/chat/chat_generator.py | 6 +- .../tests/test_chat_generator.py | 64 ++++--------------- 2 files changed, 15 insertions(+), 55 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 0329685fb..1153e581e 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -67,7 +67,7 @@ def __init__( aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, boto3_config: Optional[Dict[str, Any]] = None, ): """ @@ -117,7 +117,7 @@ def __init__( self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name self.stop_words = stop_words or [] - self.streaming_callback = streaming_callback + self.streaming_callback = streaming_callback self.boto3_config = boto3_config def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -164,7 +164,7 @@ def to_dict(self) -> Dict[str, Any]: model=self.model, stop_words=self.stop_words, generation_kwargs=self.generation_kwargs, - streaming_callback=callback_name, + streaming_callback=callback_name, boto3_config=self.boto3_config, ) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index ab980603b..5ccc3083e 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -354,18 +354,9 @@ def test_extract_replies_from_response(self, mock_boto3_session): # Test case 1: Simple text response text_response = { - "output": { - "message": { - "role": "assistant", - "content": [{"text": "This is a test response"}] - } - }, + "output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}}, "stopReason": "complete", - "usage": { - "inputTokens": 10, - "outputTokens": 20, - "totalTokens": 30 - } + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, } replies = generator.extract_replies_from_response(text_response) @@ -374,32 +365,18 @@ def test_extract_replies_from_response(self, mock_boto3_session): assert replies[0].role == ChatRole.ASSISTANT assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert replies[0].meta["finish_reason"] == "complete" - assert replies[0].meta["usage"] == { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30 - } + assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} # Test case 2: Tool use response tool_response = { "output": { "message": { "role": "assistant", - "content": [{ - "toolUse": { - "toolUseId": "123", - "name": "test_tool", - "input": {"key": "value"} - } - }] + "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}], } }, "stopReason": "tool_call", - "usage": { - "inputTokens": 15, - "outputTokens": 25, - "totalTokens": 40 - } + "usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40}, } replies = generator.extract_replies_from_response(tool_response) @@ -409,11 +386,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): assert tool_content["name"] == "test_tool" assert tool_content["input"] == {"key": "value"} assert replies[0].meta["finish_reason"] == "tool_call" - assert replies[0].meta["usage"] == { - "prompt_tokens": 15, - "completion_tokens": 25, - "total_tokens": 40 - } + assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40} # Test case 3: Mixed content response mixed_response = { @@ -422,22 +395,12 @@ def test_extract_replies_from_response(self, mock_boto3_session): "role": "assistant", "content": [ {"text": "Let me help you with that. I'll use the search tool to find the answer."}, - { - "toolUse": { - "toolUseId": "456", - "name": "search_tool", - "input": {"query": "test"} - } - } - ] + {"toolUse": {"toolUseId": "456", "name": "search_tool", "input": {"query": "test"}}}, + ], } }, "stopReason": "complete", - "usage": { - "inputTokens": 25, - "outputTokens": 35, - "totalTokens": 60 - } + "usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60}, } replies = generator.extract_replies_from_response(mixed_response) @@ -455,6 +418,7 @@ def test_process_streaming_response(self, mock_boto3_session): generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") streaming_chunks = [] + def test_callback(chunk: StreamingChunk): streaming_chunks.append(chunk) @@ -469,7 +433,7 @@ def test_callback(chunk: StreamingChunk): {"contentBlockDelta": {"delta": {"toolUse": {"input": '"test"}'}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "complete"}}, - {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}} + {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}}, ] replies = generator.process_streaming_response(events, test_callback) @@ -485,11 +449,7 @@ def test_callback(chunk: StreamingChunk): assert replies[0].content == "Let me help you." assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert replies[0].meta["finish_reason"] == "complete" - assert replies[0].meta["usage"] == { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30 - } + assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} # Check tool use reply tool_content = json.loads(replies[1].content)