Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Dec 10, 2024
1 parent c7f51a1 commit aa7c5c0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
boto3_config: Optional[Dict[str, Any]] = None,
):
"""
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
self.aws_region_name = aws_region_name
self.aws_profile_name = aws_profile_name
self.stop_words = stop_words or []
self.streaming_callback = streaming_callback
self.streaming_callback = streaming_callback
self.boto3_config = boto3_config

def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
Expand Down Expand Up @@ -164,7 +164,7 @@ def to_dict(self) -> Dict[str, Any]:
model=self.model,
stop_words=self.stop_words,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
streaming_callback=callback_name,
boto3_config=self.boto3_config,
)

Expand Down
64 changes: 12 additions & 52 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,18 +354,9 @@ def test_extract_replies_from_response(self, mock_boto3_session):

# Test case 1: Simple text response
text_response = {
"output": {
"message": {
"role": "assistant",
"content": [{"text": "This is a test response"}]
}
},
"output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}},
"stopReason": "complete",
"usage": {
"inputTokens": 10,
"outputTokens": 20,
"totalTokens": 30
}
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
}

replies = generator.extract_replies_from_response(text_response)
Expand All @@ -374,32 +365,18 @@ def test_extract_replies_from_response(self, mock_boto3_session):
assert replies[0].role == ChatRole.ASSISTANT
assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0"
assert replies[0].meta["finish_reason"] == "complete"
assert replies[0].meta["usage"] == {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}

# Test case 2: Tool use response
tool_response = {
"output": {
"message": {
"role": "assistant",
"content": [{
"toolUse": {
"toolUseId": "123",
"name": "test_tool",
"input": {"key": "value"}
}
}]
"content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}],
}
},
"stopReason": "tool_call",
"usage": {
"inputTokens": 15,
"outputTokens": 25,
"totalTokens": 40
}
"usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40},
}

replies = generator.extract_replies_from_response(tool_response)
Expand All @@ -409,11 +386,7 @@ def test_extract_replies_from_response(self, mock_boto3_session):
assert tool_content["name"] == "test_tool"
assert tool_content["input"] == {"key": "value"}
assert replies[0].meta["finish_reason"] == "tool_call"
assert replies[0].meta["usage"] == {
"prompt_tokens": 15,
"completion_tokens": 25,
"total_tokens": 40
}
assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}

# Test case 3: Mixed content response
mixed_response = {
Expand All @@ -422,22 +395,12 @@ def test_extract_replies_from_response(self, mock_boto3_session):
"role": "assistant",
"content": [
{"text": "Let me help you with that. I'll use the search tool to find the answer."},
{
"toolUse": {
"toolUseId": "456",
"name": "search_tool",
"input": {"query": "test"}
}
}
]
{"toolUse": {"toolUseId": "456", "name": "search_tool", "input": {"query": "test"}}},
],
}
},
"stopReason": "complete",
"usage": {
"inputTokens": 25,
"outputTokens": 35,
"totalTokens": 60
}
"usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60},
}

replies = generator.extract_replies_from_response(mixed_response)
Expand All @@ -455,6 +418,7 @@ def test_process_streaming_response(self, mock_boto3_session):
generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0")

streaming_chunks = []

def test_callback(chunk: StreamingChunk):
streaming_chunks.append(chunk)

Expand All @@ -469,7 +433,7 @@ def test_callback(chunk: StreamingChunk):
{"contentBlockDelta": {"delta": {"toolUse": {"input": '"test"}'}}}},
{"contentBlockStop": {}},
{"messageStop": {"stopReason": "complete"}},
{"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}}
{"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}},
]

replies = generator.process_streaming_response(events, test_callback)
Expand All @@ -485,11 +449,7 @@ def test_callback(chunk: StreamingChunk):
assert replies[0].content == "Let me help you."
assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0"
assert replies[0].meta["finish_reason"] == "complete"
assert replies[0].meta["usage"] == {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}

# Check tool use reply
tool_content = json.loads(replies[1].content)
Expand Down

0 comments on commit aa7c5c0

Please sign in to comment.