Skip to content

Commit

Permalink
chore!: Rename model_name to model in AmazonBedrockGenerator (#220)
Browse files Browse the repository at this point in the history

* rename model_name to model

* fix tests
  • Loading branch information
ZanSara authored Jan 17, 2024
1 parent ef9dd3f commit 95effa1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class AmazonBedrockGenerator:
from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator
generator = AmazonBedrockGenerator(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="...",
aws_secret_access_key="...",
Expand All @@ -71,7 +71,7 @@ class AmazonBedrockGenerator:

def __init__(
self,
model_name: str,
model: str,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
Expand All @@ -80,10 +80,10 @@ def __init__(
max_length: Optional[int] = 100,
**kwargs,
):
if not model_name:
msg = "model_name cannot be None or empty string"
if not model:
msg = "'model' cannot be None or empty string"
raise ValueError(msg)
self.model_name = model_name
self.model = model
self.max_length = max_length

try:
Expand Down Expand Up @@ -112,14 +112,14 @@ def __init__(
# It is hard to determine which tokenizer to use for the SageMaker model
# so we use GPT2 tokenizer which will likely provide good token count approximation
self.prompt_handler = DefaultPromptHandler(
model_name="gpt2",
model="gpt2",
model_max_length=model_max_length,
max_length=self.max_length or 100,
)

model_adapter_cls = self.get_model_adapter(model_name=model_name)
model_adapter_cls = self.get_model_adapter(model=model)
if not model_adapter_cls:
msg = f"AmazonBedrockGenerator doesn't support the model {model_name}."
msg = f"AmazonBedrockGenerator doesn't support the model {model}."
raise AmazonBedrockConfigurationError(msg)
self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)

Expand All @@ -146,8 +146,8 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
return str(resize_info["resized_prompt"])

@classmethod
def supports(cls, model_name, **kwargs):
model_supported = cls.get_model_adapter(model_name) is not None
def supports(cls, model, **kwargs):
model_supported = cls.get_model_adapter(model) is not None
if not model_supported or not cls.aws_configured(**kwargs):
return False

Expand All @@ -170,19 +170,19 @@ def supports(cls, model_name, **kwargs):
)
raise AmazonBedrockConfigurationError(msg) from exception

model_available = model_name in available_model_ids
model_available = model in available_model_ids
if not model_available:
msg = (
f"The model {model_name} is not available in Amazon Bedrock. "
f"The model {model} is not available in Amazon Bedrock. "
f"Make sure the model you want to use is available in the configured AWS region and "
f"you have access."
)
raise AmazonBedrockConfigurationError(msg)

stream: bool = kwargs.get("stream", False)
model_supports_streaming = model_name in model_ids_supporting_streaming
model_supports_streaming = model in model_ids_supporting_streaming
if stream and not model_supports_streaming:
msg = f"The model {model_name} doesn't support streaming. Remove the `stream` parameter."
msg = f"The model {model} doesn't support streaming. Remove the `stream` parameter."
raise AmazonBedrockConfigurationError(msg)

return model_supported
Expand All @@ -194,7 +194,7 @@ def invoke(self, *args, **kwargs):

if not prompt or not isinstance(prompt, (str, list)):
msg = (
f"The model {self.model_name} requires a valid prompt, but currently, it has no prompt. "
f"The model {self.model} requires a valid prompt, but currently, it has no prompt. "
f"Make sure to provide a prompt in the format that the model expects."
)
raise ValueError(msg)
Expand All @@ -204,7 +204,7 @@ def invoke(self, *args, **kwargs):
if stream:
response = self.client.invoke_model_with_response_stream(
body=json.dumps(body),
modelId=self.model_name,
modelId=self.model,
accept="application/json",
contentType="application/json",
)
Expand All @@ -217,15 +217,15 @@ def invoke(self, *args, **kwargs):
else:
response = self.client.invoke_model(
body=json.dumps(body),
modelId=self.model_name,
modelId=self.model,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read().decode("utf-8"))
responses = self.model_adapter.get_responses(response_body=response_body)
except ClientError as exception:
msg = (
f"Could not connect to Amazon Bedrock model {self.model_name}. "
f"Could not connect to Amazon Bedrock model {self.model}. "
f"Make sure your AWS environment is configured correctly, "
f"the model is available in the configured AWS region, and you have access."
)
Expand All @@ -238,9 +238,9 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
return {"replies": self.invoke(prompt=prompt, **(generation_kwargs or {}))}

@classmethod
def get_model_adapter(cls, model_name: str) -> Optional[Type[BedrockModelAdapter]]:
def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]:
for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
if re.fullmatch(pattern, model_name):
if re.fullmatch(pattern, model):
return adapter
return None

Expand Down Expand Up @@ -298,7 +298,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
max_length=self.max_length,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class DefaultPromptHandler:
are within the model_max_length.
"""

def __init__(self, model_name: str, model_max_length: int, max_length: int = 100):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def __init__(self, model: str, model_max_length: int, max_length: int = 100):
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.tokenizer.model_max_length = model_max_length
self.model_max_length = model_max_length
self.max_length = max_length
Expand Down
48 changes: 24 additions & 24 deletions integrations/amazon_bedrock/tests/test_amazon_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockGenerator(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
Expand All @@ -57,7 +57,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
expected_dict = {
"type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator",
"init_parameters": {
"model_name": "anthropic.claude-v2",
"model": "anthropic.claude-v2",
"max_length": 99,
},
}
Expand All @@ -74,14 +74,14 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
{
"type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator",
"init_parameters": {
"model_name": "anthropic.claude-v2",
"model": "anthropic.claude-v2",
"max_length": 99,
},
}
)

assert generator.max_length == 99
assert generator.model_name == "anthropic.claude-v2"
assert generator.model == "anthropic.claude-v2"


@pytest.mark.unit
Expand All @@ -91,7 +91,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
"""

layer = AmazonBedrockGenerator(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
Expand All @@ -101,7 +101,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
)

assert layer.max_length == 99
assert layer.model_name == "anthropic.claude-v2"
assert layer.model == "anthropic.claude-v2"

assert layer.prompt_handler is not None
assert layer.prompt_handler.model_max_length == 4096
Expand All @@ -124,7 +124,7 @@ def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_
"""
Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
"""
layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2", prompt_handler=mock_prompt_handler)
layer = AmazonBedrockGenerator(model="anthropic.claude-v2", prompt_handler=mock_prompt_handler)
assert layer.prompt_handler is not None
assert layer.prompt_handler.model_max_length == 4096

Expand All @@ -136,26 +136,26 @@ def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session):
"""
model_kwargs = {"temperature": 0.7}

layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2", **model_kwargs)
layer = AmazonBedrockGenerator(model="anthropic.claude-v2", **model_kwargs)
assert "temperature" in layer.model_adapter.model_kwargs
assert layer.model_adapter.model_kwargs["temperature"] == 0.7


@pytest.mark.unit
def test_constructor_with_empty_model_name():
def test_constructor_with_empty_model():
"""
Test that the constructor raises an error when the model_name is empty
Test that the constructor raises an error when the model is empty
"""
with pytest.raises(ValueError, match="cannot be None or empty string"):
AmazonBedrockGenerator(model_name="")
AmazonBedrockGenerator(model="")


@pytest.mark.unit
def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
"""
Test invoke raises an error if no prompt is provided
"""
layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2")
layer = AmazonBedrockGenerator(model="anthropic.claude-v2")
with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires a valid prompt."):
layer.invoke()

Expand Down Expand Up @@ -239,7 +239,7 @@ def test_supports_for_valid_aws_configuration():
return_value=mock_session,
):
supported = AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)
args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args
Expand All @@ -254,7 +254,7 @@ def test_supports_raises_on_invalid_aws_profile_name():
mock_boto3_session.side_effect = BotoCoreError()
with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"):
AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_fake_profile",
)

Expand All @@ -270,7 +270,7 @@ def test_supports_for_invalid_bedrock_config():
return_value=mock_session,
), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)

Expand All @@ -286,21 +286,21 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models():
return_value=mock_session,
), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)


@pytest.mark.unit
def test_supports_for_no_aws_params():
supported = AmazonBedrockGenerator.supports(model_name="anthropic.claude-v2")
supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2")

assert supported is False


@pytest.mark.unit
def test_supports_for_unknown_model():
supported = AmazonBedrockGenerator.supports(model_name="unknown_model", aws_profile_name="some_real_profile")
supported = AmazonBedrockGenerator.supports(model="unknown_model", aws_profile_name="some_real_profile")

assert supported is False

Expand All @@ -318,7 +318,7 @@ def test_supports_with_stream_true_for_model_that_supports_streaming():
return_value=mock_session,
):
supported = AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
stream=True,
)
Expand All @@ -342,15 +342,15 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming():
match="The model ai21.j2-mid-v1 doesn't support streaming.",
):
AmazonBedrockGenerator.supports(
model_name="ai21.j2-mid-v1",
model="ai21.j2-mid-v1",
aws_profile_name="some_real_profile",
stream=True,
)


@pytest.mark.unit
@pytest.mark.parametrize(
"model_name, expected_model_adapter",
"model, expected_model_adapter",
[
("anthropic.claude-v1", AnthropicClaudeAdapter),
("anthropic.claude-v2", AnthropicClaudeAdapter),
Expand All @@ -372,11 +372,11 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming():
("unknown_model", None),
],
)
def test_get_model_adapter(model_name: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
"""
Test that the correct model adapter is returned for a given model_name
Test that the correct model adapter is returned for a given model
"""
model_adapter = AmazonBedrockGenerator.get_model_adapter(model_name=model_name)
model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model)
assert model_adapter == expected_model_adapter


Expand Down

0 comments on commit 95effa1

Please sign in to comment.