diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 0ad78b74b..b6b358a7e 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -248,3 +248,24 @@ def test_default_inference_params(self, model_name): assert response["replies"][0].content assert ChatMessage.is_from(response["replies"][0], ChatRole.ASSISTANT) assert "paris" in response["replies"][0].content.lower() + + @pytest.mark.parametrize( + "model_name", ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] + ) + @pytest.mark.integration + def test_default_inference_with_streaming(self, model_name): + + callback_called = False + + def streaming_callback_verifier(chunk): + nonlocal callback_called + callback_called = True + + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback_verifier) + client.run(messages) + assert callback_called