diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
index eca81c3f1..40ba0bc67 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py
@@ -150,7 +150,12 @@ class AmazonTitanAdapter(BedrockModelAdapter):
     """

     def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
-        default_params = {"maxTokenCount": self.max_length, "stopSequences": None, "temperature": None, "topP": None}
+        default_params = {
+            "maxTokenCount": self.max_length,
+            "stopSequences": None,
+            "temperature": None,
+            "topP": None,
+        }
         params = self._get_params(inference_kwargs, default_params)

         body = {"inputText": prompt, "textGenerationConfig": params}
@@ -170,7 +175,11 @@ class MetaLlama2ChatAdapter(BedrockModelAdapter):
     """

     def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
-        default_params = {"max_gen_len": self.max_length, "temperature": None, "top_p": None}
+        default_params = {
+            "max_gen_len": self.max_length,
+            "temperature": None,
+            "top_p": None,
+        }
         params = self._get_params(inference_kwargs, default_params)

         body = {"prompt": prompt, **params}
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py
index 53c28ad1d..aa8a3f6e4 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py
@@ -10,7 +10,10 @@ class AmazonBedrockError(Exception):
     `AmazonBedrockError.message` will exist and have the expected content.
""" - def __init__(self, message: Optional[str] = None): + def __init__( + self, + message: Optional[str] = None, + ): super().__init__() if message: self.message = message diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 48f22f59b..4c43c9a09 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -15,8 +15,16 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from .errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError -from .handlers import DefaultPromptHandler, DefaultTokenStreamingHandler, TokenStreamingHandler +from .errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, + AWSConfigurationError, +) +from .handlers import ( + DefaultPromptHandler, + DefaultTokenStreamingHandler, + TokenStreamingHandler, +) logger = logging.getLogger(__name__) @@ -37,7 +45,7 @@ class AmazonBedrockGenerator: Usage example: ```python - from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator + from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator generator = AmazonBedrockGenerator( model="anthropic.claude-v2", @@ -104,7 +112,9 @@ def __init__( # It is hard to determine which tokenizer to use for the SageMaker model # so we use GPT2 tokenizer which will likely provide good token count approximation self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100 + model="gpt2", + model_max_length=model_max_length, + max_length=self.max_length or 100, ) model_adapter_cls = self.get_model_adapter(model=model) @@ -193,7 +203,10 @@ def invoke(self, *args, **kwargs): try: if stream: response = self.client.invoke_model_with_response_stream( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + body=json.dumps(body), + modelId=self.model, + accept="application/json", + contentType="application/json", ) response_stream = response["body"] handler: TokenStreamingHandler = kwargs.get( @@ -203,7 +216,10 @@ def invoke(self, *args, **kwargs): responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler) else: response = self.client.invoke_model( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + body=json.dumps(body), + modelId=self.model, + accept="application/json", + contentType="application/json", ) response_body = json.loads(response.get("body").read().decode("utf-8")) responses = self.model_adapter.get_responses(response_body=response_body) @@ -280,7 +296,11 @@ def to_dict(self) -> Dict[str, Any]: Serialize this component to a dictionary. :return: The serialized component as a dictionary. 
""" - return default_to_dict(self, model=self.model, max_length=self.max_length) + return default_to_dict( + self, + model=self.model, + max_length=self.max_length, + ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator": diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index 6be07b06a..b08e9dfd5 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -56,7 +56,10 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", - "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99}, + "init_parameters": { + "model": "anthropic.claude-v2", + "max_length": 99, + }, } assert generator.to_dict() == expected_dict @@ -70,7 +73,10 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): generator = AmazonBedrockGenerator.from_dict( { "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", - "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99}, + "init_parameters": { + "model": "anthropic.claude-v2", + "max_length": 99, + }, } ) @@ -175,7 +181,9 @@ def test_short_prompt_is_not_truncated(mock_boto3_session): with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( - "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length + "anthropic.claude-v2", + max_length=max_length_generated_text, + model_max_length=total_model_max_length, ) prompt_after_resize = layer._ensure_token_limit(mock_prompt_text) @@ -208,7 +216,9 @@ def test_long_prompt_is_truncated(mock_boto3_session): with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( - "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length + "anthropic.claude-v2", + max_length=max_length_generated_text, + model_max_length=total_model_max_length, ) prompt_after_resize = layer._ensure_token_limit(long_prompt_text) @@ -228,7 +238,10 @@ def test_supports_for_valid_aws_configuration(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): - supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") + supported = AmazonBedrockGenerator.supports( + model="anthropic.claude-v2", + aws_profile_name="some_real_profile", + ) args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args assert kwargs["byOutputModality"] == "TEXT" @@ -240,7 +253,10 @@ def test_supports_raises_on_invalid_aws_profile_name(): with patch("boto3.Session") as mock_boto3_session: mock_boto3_session.side_effect = BotoCoreError() with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"): - AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_fake_profile") + AmazonBedrockGenerator.supports( + model="anthropic.claude-v2", + aws_profile_name="some_fake_profile", + ) @pytest.mark.unit @@ -253,7 +269,10 @@ def test_supports_for_invalid_bedrock_config(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), 
pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): - AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") + AmazonBedrockGenerator.supports( + model="anthropic.claude-v2", + aws_profile_name="some_real_profile", + ) @pytest.mark.unit @@ -266,7 +285,10 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models(): "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): - AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile") + AmazonBedrockGenerator.supports( + model="anthropic.claude-v2", + aws_profile_name="some_real_profile", + ) @pytest.mark.unit @@ -296,7 +318,9 @@ def test_supports_with_stream_true_for_model_that_supports_streaming(): return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( - model="anthropic.claude-v2", aws_profile_name="some_real_profile", stream=True + model="anthropic.claude-v2", + aws_profile_name="some_real_profile", + stream=True, ) assert supported @@ -313,8 +337,15 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): with patch( "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, - ), pytest.raises(AmazonBedrockConfigurationError, match="The model ai21.j2-mid-v1 doesn't support streaming."): - AmazonBedrockGenerator.supports(model="ai21.j2-mid-v1", aws_profile_name="some_real_profile", stream=True) + ), pytest.raises( + AmazonBedrockConfigurationError, + match="The model ai21.j2-mid-v1 doesn't support streaming.", + ): + AmazonBedrockGenerator.supports( + model="ai21.j2-mid-v1", + aws_profile_name="some_real_profile", + stream=True, + ) @pytest.mark.unit @@ -634,9 +665,15 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_responses_multiple_responses(self) -> None: adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) response_body = { - "generations": [{"text": "This is a single response."}, {"text": "This is a second response."}] + "generations": [ + {"text": "This is a single response."}, + {"text": "This is a second response."}, + ] } - expected_responses = ["This is a single response.", "This is a second response."] + expected_responses = [ + "This is a single response.", + "This is a second response.", + ] assert adapter.get_responses(response_body) == expected_responses def test_get_stream_responses(self) -> None: @@ -817,7 +854,10 @@ def test_get_responses_multiple_responses(self) -> None: {"data": {"text": "This is a second response."}}, ] } - expected_responses = ["This is a single response.", "This is a second response."] + expected_responses = [ + "This is a single response.", + "This is a second response.", + ] assert adapter.get_responses(response_body) == expected_responses @@ -825,7 +865,10 @@ class TestAmazonTitanAdapter: def test_prepare_body_with_default_params(self) -> None: layer = AmazonTitanAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" 
- expected_body = {"inputText": "Hello, how are you?", "textGenerationConfig": {"maxTokenCount": 99}} + expected_body = { + "inputText": "Hello, how are you?", + "textGenerationConfig": {"maxTokenCount": 99}, + } body = layer.prepare_body(prompt) @@ -921,9 +964,15 @@ def test_get_responses_leading_whitespace(self) -> None: def test_get_responses_multiple_responses(self) -> None: adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) response_body = { - "results": [{"outputText": "This is a single response."}, {"outputText": "This is a second response."}] + "results": [ + {"outputText": "This is a single response."}, + {"outputText": "This is a second response."}, + ] } - expected_responses = ["This is a single response.", "This is a second response."] + expected_responses = [ + "This is a single response.", + "This is a second response.", + ] assert adapter.get_responses(response_body) == expected_responses def test_get_stream_responses(self) -> None: @@ -982,19 +1031,40 @@ def test_prepare_body_with_default_params(self) -> None: def test_prepare_body_with_custom_inference_params(self) -> None: layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) prompt = "Hello, how are you?" - expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8} + expected_body = { + "prompt": "Hello, how are you?", + "max_gen_len": 50, + "temperature": 0.7, + "top_p": 0.8, + } - body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, max_gen_len=50, unknown_arg="unknown_value") + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + max_gen_len=50, + unknown_arg="unknown_value", + ) assert body == expected_body def test_prepare_body_with_model_kwargs(self) -> None: layer = MetaLlama2ChatAdapter( - model_kwargs={"temperature": 0.7, "top_p": 0.8, "max_gen_len": 50, "unknown_arg": "unknown_value"}, + model_kwargs={ + "temperature": 0.7, + "top_p": 0.8, + "max_gen_len": 50, + "unknown_arg": "unknown_value", + }, max_length=99, ) prompt = "Hello, how are you?" - expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8} + expected_body = { + "prompt": "Hello, how are you?", + "max_gen_len": 50, + "temperature": 0.7, + "top_p": 0.8, + } body = layer.prepare_body(prompt) @@ -1002,10 +1072,21 @@ def test_prepare_body_with_model_kwargs(self) -> None: def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: layer = MetaLlama2ChatAdapter( - model_kwargs={"temperature": 0.6, "top_p": 0.7, "top_k": 4, "max_gen_len": 49}, max_length=99 + model_kwargs={ + "temperature": 0.6, + "top_p": 0.7, + "top_k": 4, + "max_gen_len": 49, + }, + max_length=99, ) prompt = "Hello, how are you?" - expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.7} + expected_body = { + "prompt": "Hello, how are you?", + "max_gen_len": 50, + "temperature": 0.7, + "top_p": 0.7, + } body = layer.prepare_body(prompt, temperature=0.7, max_gen_len=50)
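
Reviewer note: the MetaLlama2ChatAdapter tests above pin down the parameter-merging contract that the adapters' `_get_params` helper must satisfy — per-call inference kwargs override `model_kwargs`, which in turn override the defaults; defaults left at `None` are dropped; and keys absent from the defaults (`top_k`, `unknown_arg`) are discarded entirely. A minimal, self-contained sketch of that contract, independent of the Haystack codebase (`merge_params` is a hypothetical stand-in, not the integration's actual helper):

```python
from typing import Any, Dict, Optional


def merge_params(
    defaults: Dict[str, Optional[Any]],
    model_kwargs: Dict[str, Any],
    inference_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Hypothetical re-implementation of the merging rules the tests assert."""
    merged = {}
    # Only keys present in `defaults` are kept; everything else is discarded.
    for key, default in defaults.items():
        # Per-call kwargs win over model_kwargs, which win over the default.
        value = inference_kwargs.get(key, model_kwargs.get(key, default))
        # Defaults left at None are dropped from the request body.
        if value is not None:
            merged[key] = value
    return merged


# Mirrors test_prepare_body_with_model_kwargs_and_custom_inference_params:
defaults = {"max_gen_len": 99, "temperature": None, "top_p": None}
model_kwargs = {"temperature": 0.6, "top_p": 0.7, "top_k": 4, "max_gen_len": 49}
params = merge_params(defaults, model_kwargs, {"temperature": 0.7, "max_gen_len": 50})
assert params == {"max_gen_len": 50, "temperature": 0.7, "top_p": 0.7}  # top_k discarded
```

Treating `defaults` as a whitelist is what makes unsupported keys such as `unknown_arg` vanish from the prepared body instead of reaching the Bedrock API.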