From 839adda8c446b22c536a3fee97a88fbdcf702863 Mon Sep 17 00:00:00 2001 From: alperkaya Date: Sat, 28 Sep 2024 10:12:28 +0200 Subject: [PATCH] remove stream arg --- .../generators/amazon_bedrock/chat/chat_generator.py | 4 ++-- integrations/amazon_bedrock/tests/test_chat_generator.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 988452a97..719130f0b 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -173,7 +173,7 @@ def run( generation_kwargs = generation_kwargs.copy() streaming_callback = streaming_callback or self.streaming_callback - generation_kwargs["stream"] = streaming_callback is not None + is_streaming_enabled = streaming_callback is not None # check if the prompt is a list of ChatMessage objects if not ( @@ -188,7 +188,7 @@ def run( messages=messages, **{"stop_words": self.stop_words, **generation_kwargs} ) try: - if streaming_callback: + if is_streaming_enabled: response = self.client.invoke_model_with_response_stream( body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" ) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index a455d2c93..df3fb4381 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -246,7 +246,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): mock_ensure_token_limit.assert_not_called(), # Check the prompt passed to prepare_body - generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False) + generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[]) @pytest.mark.parametrize(