Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore!: Rename model_name to model in AmazonBedrockGenerator #220

Merged
merged 2 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class AmazonBedrockGenerator:
from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator
generator = AmazonBedrockGenerator(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="...",
aws_secret_access_key="...",
Expand All @@ -71,7 +71,7 @@ class AmazonBedrockGenerator:

def __init__(
self,
model_name: str,
model: str,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
Expand All @@ -80,10 +80,10 @@ def __init__(
max_length: Optional[int] = 100,
**kwargs,
):
if not model_name:
msg = "model_name cannot be None or empty string"
if not model:
msg = "'model' cannot be None or empty string"
raise ValueError(msg)
self.model_name = model_name
self.model = model
self.max_length = max_length

try:
Expand Down Expand Up @@ -112,14 +112,14 @@ def __init__(
# It is hard to determine which tokenizer to use for the SageMaker model
# so we use GPT2 tokenizer which will likely provide good token count approximation
self.prompt_handler = DefaultPromptHandler(
model_name="gpt2",
model="gpt2",
model_max_length=model_max_length,
max_length=self.max_length or 100,
)

model_adapter_cls = self.get_model_adapter(model_name=model_name)
model_adapter_cls = self.get_model_adapter(model=model)
if not model_adapter_cls:
msg = f"AmazonBedrockGenerator doesn't support the model {model_name}."
msg = f"AmazonBedrockGenerator doesn't support the model {model}."
raise AmazonBedrockConfigurationError(msg)
self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)

Expand All @@ -146,8 +146,8 @@ def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union
return str(resize_info["resized_prompt"])

@classmethod
def supports(cls, model_name, **kwargs):
model_supported = cls.get_model_adapter(model_name) is not None
def supports(cls, model, **kwargs):
model_supported = cls.get_model_adapter(model) is not None
if not model_supported or not cls.aws_configured(**kwargs):
return False

Expand All @@ -170,19 +170,19 @@ def supports(cls, model_name, **kwargs):
)
raise AmazonBedrockConfigurationError(msg) from exception

model_available = model_name in available_model_ids
model_available = model in available_model_ids
if not model_available:
msg = (
f"The model {model_name} is not available in Amazon Bedrock. "
f"The model {model} is not available in Amazon Bedrock. "
f"Make sure the model you want to use is available in the configured AWS region and "
f"you have access."
)
raise AmazonBedrockConfigurationError(msg)

stream: bool = kwargs.get("stream", False)
model_supports_streaming = model_name in model_ids_supporting_streaming
model_supports_streaming = model in model_ids_supporting_streaming
if stream and not model_supports_streaming:
msg = f"The model {model_name} doesn't support streaming. Remove the `stream` parameter."
msg = f"The model {model} doesn't support streaming. Remove the `stream` parameter."
raise AmazonBedrockConfigurationError(msg)

return model_supported
Expand All @@ -194,7 +194,7 @@ def invoke(self, *args, **kwargs):

if not prompt or not isinstance(prompt, (str, list)):
msg = (
f"The model {self.model_name} requires a valid prompt, but currently, it has no prompt. "
f"The model {self.model} requires a valid prompt, but currently, it has no prompt. "
f"Make sure to provide a prompt in the format that the model expects."
)
raise ValueError(msg)
Expand All @@ -204,7 +204,7 @@ def invoke(self, *args, **kwargs):
if stream:
response = self.client.invoke_model_with_response_stream(
body=json.dumps(body),
modelId=self.model_name,
modelId=self.model,
accept="application/json",
contentType="application/json",
)
Expand All @@ -217,15 +217,15 @@ def invoke(self, *args, **kwargs):
else:
response = self.client.invoke_model(
body=json.dumps(body),
modelId=self.model_name,
modelId=self.model,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read().decode("utf-8"))
responses = self.model_adapter.get_responses(response_body=response_body)
except ClientError as exception:
msg = (
f"Could not connect to Amazon Bedrock model {self.model_name}. "
f"Could not connect to Amazon Bedrock model {self.model}. "
f"Make sure your AWS environment is configured correctly, "
f"the model is available in the configured AWS region, and you have access."
)
Expand All @@ -238,9 +238,9 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
return {"replies": self.invoke(prompt=prompt, **(generation_kwargs or {}))}

@classmethod
def get_model_adapter(cls, model_name: str) -> Optional[Type[BedrockModelAdapter]]:
def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelAdapter]]:
for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
if re.fullmatch(pattern, model_name):
if re.fullmatch(pattern, model):
return adapter
return None

Expand Down Expand Up @@ -298,7 +298,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
max_length=self.max_length,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class DefaultPromptHandler:
are within the model_max_length.
"""

def __init__(self, model_name: str, model_max_length: int, max_length: int = 100):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def __init__(self, model: str, model_max_length: int, max_length: int = 100):
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.tokenizer.model_max_length = model_max_length
self.model_max_length = model_max_length
self.max_length = max_length
Expand Down
48 changes: 24 additions & 24 deletions integrations/amazon_bedrock/tests/test_amazon_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockGenerator(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
Expand All @@ -57,7 +57,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
expected_dict = {
"type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator",
"init_parameters": {
"model_name": "anthropic.claude-v2",
"model": "anthropic.claude-v2",
"max_length": 99,
},
}
Expand All @@ -74,14 +74,14 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
{
"type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator",
"init_parameters": {
"model_name": "anthropic.claude-v2",
"model": "anthropic.claude-v2",
"max_length": 99,
},
}
)

assert generator.max_length == 99
assert generator.model_name == "anthropic.claude-v2"
assert generator.model == "anthropic.claude-v2"


@pytest.mark.unit
Expand All @@ -91,7 +91,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
"""

layer = AmazonBedrockGenerator(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
Expand All @@ -101,7 +101,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
)

assert layer.max_length == 99
assert layer.model_name == "anthropic.claude-v2"
assert layer.model == "anthropic.claude-v2"

assert layer.prompt_handler is not None
assert layer.prompt_handler.model_max_length == 4096
Expand All @@ -124,7 +124,7 @@ def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_
"""
Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
"""
layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2", prompt_handler=mock_prompt_handler)
layer = AmazonBedrockGenerator(model="anthropic.claude-v2", prompt_handler=mock_prompt_handler)
assert layer.prompt_handler is not None
assert layer.prompt_handler.model_max_length == 4096

Expand All @@ -136,26 +136,26 @@ def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session):
"""
model_kwargs = {"temperature": 0.7}

layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2", **model_kwargs)
layer = AmazonBedrockGenerator(model="anthropic.claude-v2", **model_kwargs)
assert "temperature" in layer.model_adapter.model_kwargs
assert layer.model_adapter.model_kwargs["temperature"] == 0.7


@pytest.mark.unit
def test_constructor_with_empty_model_name():
def test_constructor_with_empty_model():
"""
Test that the constructor raises an error when the model_name is empty
Test that the constructor raises an error when the model is empty
"""
with pytest.raises(ValueError, match="cannot be None or empty string"):
AmazonBedrockGenerator(model_name="")
AmazonBedrockGenerator(model="")


@pytest.mark.unit
def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
"""
Test invoke raises an error if no prompt is provided
"""
layer = AmazonBedrockGenerator(model_name="anthropic.claude-v2")
layer = AmazonBedrockGenerator(model="anthropic.claude-v2")
with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires a valid prompt."):
layer.invoke()

Expand Down Expand Up @@ -239,7 +239,7 @@ def test_supports_for_valid_aws_configuration():
return_value=mock_session,
):
supported = AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)
args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args
Expand All @@ -254,7 +254,7 @@ def test_supports_raises_on_invalid_aws_profile_name():
mock_boto3_session.side_effect = BotoCoreError()
with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"):
AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_fake_profile",
)

Expand All @@ -270,7 +270,7 @@ def test_supports_for_invalid_bedrock_config():
return_value=mock_session,
), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)

Expand All @@ -286,21 +286,21 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models():
return_value=mock_session,
), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
)


@pytest.mark.unit
def test_supports_for_no_aws_params():
supported = AmazonBedrockGenerator.supports(model_name="anthropic.claude-v2")
supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2")

assert supported is False


@pytest.mark.unit
def test_supports_for_unknown_model():
supported = AmazonBedrockGenerator.supports(model_name="unknown_model", aws_profile_name="some_real_profile")
supported = AmazonBedrockGenerator.supports(model="unknown_model", aws_profile_name="some_real_profile")

assert supported is False

Expand All @@ -318,7 +318,7 @@ def test_supports_with_stream_true_for_model_that_supports_streaming():
return_value=mock_session,
):
supported = AmazonBedrockGenerator.supports(
model_name="anthropic.claude-v2",
model="anthropic.claude-v2",
aws_profile_name="some_real_profile",
stream=True,
)
Expand All @@ -342,15 +342,15 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming():
match="The model ai21.j2-mid-v1 doesn't support streaming.",
):
AmazonBedrockGenerator.supports(
model_name="ai21.j2-mid-v1",
model="ai21.j2-mid-v1",
aws_profile_name="some_real_profile",
stream=True,
)


@pytest.mark.unit
@pytest.mark.parametrize(
"model_name, expected_model_adapter",
"model, expected_model_adapter",
[
("anthropic.claude-v1", AnthropicClaudeAdapter),
("anthropic.claude-v2", AnthropicClaudeAdapter),
Expand All @@ -372,11 +372,11 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming():
("unknown_model", None),
],
)
def test_get_model_adapter(model_name: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]):
"""
Test that the correct model adapter is returned for a given model_name
Test that the correct model adapter is returned for a given model
"""
model_adapter = AmazonBedrockGenerator.get_model_adapter(model_name=model_name)
model_adapter = AmazonBedrockGenerator.get_model_adapter(model=model)
assert model_adapter == expected_model_adapter


Expand Down