diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py
index 08b970bad..dda84fe14 100644
--- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py
+++ b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py
@@ -48,7 +48,7 @@ class AmazonBedrockGenerator:
     from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator

     generator = AmazonBedrockGenerator(
-        model_name="anthropic.claude-v2",
+        model="anthropic.claude-v2",
         max_length=99,
         aws_access_key_id="...",
         aws_secret_access_key="...",
@@ -71,7 +71,7 @@ class AmazonBedrockGenerator:

     def __init__(
         self,
-        model_name: str,
+        model: str,
         aws_access_key_id: Optional[str] = None,
         aws_secret_access_key: Optional[str] = None,
         aws_session_token: Optional[str] = None,
@@ -80,10 +80,10 @@ def __init__(
         max_length: Optional[int] = 100,
         **kwargs,
     ):
-        if not model_name:
-            msg = "model_name cannot be None or empty string"
+        if not model:
+            msg = "'model' cannot be None or empty string"
             raise ValueError(msg)
-        self.model_name = model_name
+        self.model = model
         self.max_length = max_length

         try:
@@ -112,14 +112,14 @@ def __init__(
             # It is hard to determine which tokenizer to use for the SageMaker model
             # so we use GPT2 tokenizer which will likely provide good token count approximation
             self.prompt_handler = DefaultPromptHandler(
-                model_name="gpt2",
+                model="gpt2",
                 model_max_length=model_max_length,
                 max_length=self.max_length or 100,
             )

-        model_adapter_cls = self.get_model_adapter(model_name=model_name)
+        model_adapter_cls = self.get_model_adapter(model=model)
         if not model_adapter_cls:
-            msg = f"AmazonBedrockGenerator doesn't support the model {model_name}."
+            msg = f"AmazonBedrockGenerator doesn't support the model {model}."
             raise AmazonBedrockConfigurationError(msg)
         self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)

@@ -146,8 +146,8 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
             return str(resize_info["resized_prompt"])

     @classmethod
-    def supports(cls, model_name, **kwargs):
-        model_supported = cls.get_model_adapter(model_name) is not None
+    def supports(cls, model, **kwargs):
+        model_supported = cls.get_model_adapter(model) is not None

         if not model_supported or not cls.aws_configured(**kwargs):
             return False
@@ -170,19 +170,19 @@ def supports(cls, model_name, **kwargs):
             )
             raise AmazonBedrockConfigurationError(msg) from exception

-        model_available = model_name in available_model_ids
+        model_available = model in available_model_ids
         if not model_available:
             msg = (
-                f"The model {model_name} is not available in Amazon Bedrock. "
+                f"The model {model} is not available in Amazon Bedrock. "
                 f"Make sure the model you want to use is available in the configured AWS region and "
                 f"you have access."
             )
             raise AmazonBedrockConfigurationError(msg)

         stream: bool = kwargs.get("stream", False)
-        model_supports_streaming = model_name in model_ids_supporting_streaming
+        model_supports_streaming = model in model_ids_supporting_streaming
         if stream and not model_supports_streaming:
-            msg = f"The model {model_name} doesn't support streaming. Remove the `stream` parameter."
+            msg = f"The model {model} doesn't support streaming. Remove the `stream` parameter."
             raise AmazonBedrockConfigurationError(msg)

         return model_supported
@@ -194,7 +194,7 @@ def invoke(self, *args, **kwargs):

         if not prompt or not isinstance(prompt, (str, list)):
             msg = (
-                f"The model {self.model_name} requires a valid prompt, but currently, it has no prompt. "
+                f"The model {self.model} requires a valid prompt, but currently, it has no prompt. "
                 f"Make sure to provide a prompt in the format that the model expects."
             )
             raise ValueError(msg)
@@ -204,7 +204,7 @@ def invoke(self, *args, **kwargs):
             if stream:
                 response = self.client.invoke_model_with_response_stream(
                     body=json.dumps(body),
-                    modelId=self.model_name,
+                    modelId=self.model,
                     accept="application/json",
                     contentType="application/json",
                 )
@@ -217,7 +217,7 @@ def invoke(self, *args, **kwargs):
             else:
                 response = self.client.invoke_model(
                     body=json.dumps(body),
-                    modelId=self.model_name,
+                    modelId=self.model,
                     accept="application/json",
                     contentType="application/json",
                 )
@@ -225,7 +225,7 @@ def invoke(self, *args, **kwargs):
                 responses = self.model_adapter.get_responses(response_body=response_body)
         except ClientError as exception:
             msg = (
-                f"Could not connect to Amazon Bedrock model {self.model_name}. "
+                f"Could not connect to Amazon Bedrock model {self.model}. "
                 f"Make sure your AWS environment is configured correctly, "
                 f"the model is available in the configured AWS region, and you have access."
             )
@@ -238,9 +238,9 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         return {"replies": self.invoke(prompt=prompt, **(generation_kwargs or {}))}

     @classmethod
-    def get_model_adapter(cls, model_name: str) -> Optional[Type[BedrockModelAdapter]]:
+    def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]:
         for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
-            if re.fullmatch(pattern, model_name):
+            if re.fullmatch(pattern, model):
                 return adapter
         return None
@@ -298,7 +298,7 @@ def to_dict(self) -> Dict[str, Any]:
         """
         return default_to_dict(
             self,
-            model_name=self.model_name,
+            model=self.model,
             max_length=self.max_length,
         )
diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py
index ddfe686f8..56dcb24d3 100644
--- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py
+++ b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py
@@ -10,8 +10,8 @@ class DefaultPromptHandler:
     are within the model_max_length.
     """

-    def __init__(self, model_name: str, model_max_length: int, max_length: int = 100):
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+    def __init__(self, model: str, model_max_length: int, max_length: int = 100):
+        self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.tokenizer.model_max_length = model_max_length
         self.model_max_length = model_max_length
         self.max_length = max_length
diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py
index a2a484340..a05c95ba3 100644
--- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py
+++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py
@@ -45,7 +45,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
     Test that the to_dict method returns the correct dictionary without aws credentials
     """
     generator = AmazonBedrockGenerator(
-        model_name="anthropic.claude-v2",
+        model="anthropic.claude-v2",
         max_length=99,
         aws_access_key_id="some_fake_id",
         aws_secret_access_key="some_fake_key",
@@ -57,7 +57,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
     expected_dict = {
         "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator",
         "init_parameters": {
-            "model_name": "anthropic.claude-v2",
+            "model": "anthropic.claude-v2",
             "max_length": 99,
         },
     }
@@ -74,14 +74,14 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
         {
             "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator",
             "init_parameters": {
-                "model_name": "anthropic.claude-v2",
+                "model": "anthropic.claude-v2",
                 "max_length": 99,
             },
         }
     )

     assert generator.max_length == 99
-    assert generator.model_name == "anthropic.claude-v2"
+    assert generator.model == "anthropic.claude-v2"


 @pytest.mark.unit
@@ -91,7 +91,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
     """

     layer = AmazonBedrockGenerator(
-        model_name="anthropic.claude-v2",
+        model="anthropic.claude-v2",
         max_length=99,
         aws_access_key_id="some_fake_id",
         aws_secret_access_key="some_fake_key",
@@ -101,7 +101,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
     )

     assert layer.max_length == 99
-    assert layer.model_name == "anthropic.claude-v2"
+    assert layer.model == "anthropic.claude-v2"

     assert layer.prompt_handler is not None
     assert layer.prompt_handler.model_max_length == 4096
@@ -124,7 +124,7 @@ def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_
     """
     Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
     """
-    layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2", prompt_handler=mock_prompt_handler)
+    layer = AmazonBedrockGenerator(model="anthropic.claude-v2", prompt_handler=mock_prompt_handler)
     assert layer.prompt_handler is not None
     assert layer.prompt_handler.model_max_length == 4096

@@ -136,18 +136,18 @@ def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session):
     """
     model_kwargs = {"temperature": 0.7}

-    layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2", **model_kwargs)
+    layer = AmazonBedrockGenerator(model="anthropic.claude-v2", **model_kwargs)

     assert "temperature" in layer.model_adapter.model_kwargs
     assert layer.model_adapter.model_kwargs["temperature"] == 0.7


 @pytest.mark.unit
-def test_constructor_with_empty_model_name():
+def test_constructor_with_empty_model():
     """
-    Test that the constructor raises an error when the model_name is empty
+    Test that the constructor raises an error when the model is empty
     """
     with pytest.raises(ValueError, match="cannot be None or empty string"):
-        AmazonBedrockGenerator(model_name="")
+        AmazonBedrockGenerator(model="")


 @pytest.mark.unit
@@ -155,7 +155,7 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
     """
     Test invoke raises an error if no prompt is provided
    """
-    layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2")
+    layer = AmazonBedrockGenerator(model="anthropic.claude-v2")
     with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires a valid prompt."):
         layer.invoke()

@@ -239,7 +239,7 @@ def test_supports_for_valid_aws_configuration():
         return_value=mock_session,
     ):
         supported = AmazonBedrockGenerator.supports(
-            model_name="anthropic.claude-v2",
+            model="anthropic.claude-v2",
             aws_profile_name="some_real_profile",
         )
     args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args
@@ -254,7 +254,7 @@ def test_supports_raises_on_invalid_aws_profile_name():
         mock_boto3_session.side_effect = BotoCoreError()
         with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"):
             AmazonBedrockGenerator.supports(
-                model_name="anthropic.claude-v2",
+                model="anthropic.claude-v2",
                 aws_profile_name="some_fake_profile",
             )

@@ -270,7 +270,7 @@ def test_supports_for_invalid_bedrock_config():
         return_value=mock_session,
     ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
         AmazonBedrockGenerator.supports(
-            model_name="anthropic.claude-v2",
+            model="anthropic.claude-v2",
             aws_profile_name="some_real_profile",
         )

@@ -286,21 +286,21 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models():
         return_value=mock_session,
     ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
         AmazonBedrockGenerator.supports(
-            model_name="anthropic.claude-v2",
+            model="anthropic.claude-v2",
             aws_profile_name="some_real_profile",
         )


 @pytest.mark.unit
 def test_supports_for_no_aws_params():
-    supported = AmazonBedrockGenerator.supports(model_name="anthropic.claude-v2")
+    supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2")

     assert supported is False


 @pytest.mark.unit
 def test_supports_for_unknown_model():
-    supported = AmazonBedrockGenerator.supports(model_name="unknown_model", aws_profile_name="some_real_profile")
+    supported = AmazonBedrockGenerator.supports(model="unknown_model", aws_profile_name="some_real_profile")

     assert supported is False

@@ -318,7 +318,7 @@ def test_supports_with_stream_true_for_model_that_supports_streaming():
         return_value=mock_session,
     ):
         supported = AmazonBedrockGenerator.supports(
-            model_name="anthropic.claude-v2",
+            model="anthropic.claude-v2",
             aws_profile_name="some_real_profile",
             stream=True,
         )
@@ -342,7 +342,7 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming():
         match="The model ai21.j2-mid-v1 doesn't support streaming.",
     ):
         AmazonBedrockGenerator.supports(
-            model_name="ai21.j2-mid-v1",
+            model="ai21.j2-mid-v1",
             aws_profile_name="some_real_profile",
             stream=True,
         )
@@ -350,7 +350,7 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming():

 @pytest.mark.unit
 @pytest.mark.parametrize(
-    "model_name, expected_model_adapter",
+    "model, expected_model_adapter",
     [
         ("anthropic.claude-v1", AnthropicClaudeAdapter),
         ("anthropic.claude-v2", AnthropicClaudeAdapter),
@@ -372,11 +372,11 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming():
         ("unknown_model", None),
     ],
 )
-def test_get_model_adapter(model_name: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
+def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
     """
-    Test that the correct model adapter is returned for a given model_name
+    Test that the correct model adapter is returned for a given model
     """
-    model_adapter = AmazonBedrockGenerator.get_model_adapter(model_name=model_name)
+    model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model)

     assert model_adapter == expected_model_adapter
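
Net effect of this diff: every public entry point of AmazonBedrockGenerator (the constructor, supports, get_model_adapter, and to_dict) now takes `model` instead of `model_name`. A minimal usage sketch follows, adapted from the class docstring example above; the prompt string and temperature value are illustrative placeholders, and the "..." credentials stand in for real AWS keys or a configured profile.

from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator

# The keyword is now `model`. After this change, passing model_name="..." is
# swallowed by **kwargs and the required `model` argument is missing, so the
# constructor raises a TypeError.
generator = AmazonBedrockGenerator(
    model="anthropic.claude-v2",
    max_length=99,
    aws_access_key_id="...",
    aws_secret_access_key="...",
)

# run() wraps invoke() and forwards generation_kwargs to the model adapter;
# the prompt and temperature below are placeholders, not part of this diff.
result = generator.run(prompt="Hello Bedrock", generation_kwargs={"temperature": 0.7})
print(result["replies"])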
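The serialized form changes in the same way: to_dict now emits a "model" key under "init_parameters" (see test_to_dict), and test_from_dict exercises loading from the same shape. A round-trip sketch under the same assumptions as those tests, namely that tokenizer download and boto3 session creation succeed (the tests mock both):

from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator

# Assumes AWS credentials are available in the environment.
generator = AmazonBedrockGenerator(model="anthropic.claude-v2", max_length=99)

data = generator.to_dict()
# Per test_to_dict, `data` now looks like:
# {
#     "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator",
#     "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99},
# }

restored = AmazonBedrockGenerator.from_dict(data)
assert restored.model == "anthropic.claude-v2"
assert restored.max_length == 99

Note that dictionaries serialized before this change still carry "model_name" under "init_parameters"; those will likely fail to load until the key is renamed to "model".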