Commit

Pylint
vblagoje committed Feb 6, 2024
1 parent d3c8157 commit 3e842c5
Showing 5 changed files with 34 additions and 155 deletions.
@@ -150,12 +150,7 @@ class AmazonTitanAdapter(BedrockModelAdapter):
     """
 
     def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
-        default_params = {
-            "maxTokenCount": self.max_length,
-            "stopSequences": None,
-            "temperature": None,
-            "topP": None,
-        }
+        default_params = {"maxTokenCount": self.max_length, "stopSequences": None, "temperature": None, "topP": None}
         params = self._get_params(inference_kwargs, default_params)
 
         body = {"inputText": prompt, "textGenerationConfig": params}
@@ -175,11 +170,7 @@ class MetaLlama2ChatAdapter(BedrockModelAdapter):
     """
 
    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
-        default_params = {
-            "max_gen_len": self.max_length,
-            "temperature": None,
-            "top_p": None,
-        }
+        default_params = {"max_gen_len": self.max_length, "temperature": None, "top_p": None}
         params = self._get_params(inference_kwargs, default_params)
 
         body = {"prompt": prompt, **params}
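For orientation, a minimal sketch of how the reflowed prepare_body is called, mirroring test_prepare_body_with_default_params further down in this commit. The import path is an assumption (the adapters' module path is not shown in this diff):

# Hypothetical import path, inferred rather than shown in this diff.
from haystack_integrations.components.generators.amazon_bedrock.adapters import AmazonTitanAdapter

adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99)
body = adapter.prepare_body("Hello, how are you?")
# Expected shape, per the test below:
# {"inputText": "Hello, how are you?", "textGenerationConfig": {"maxTokenCount": 99}}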
@@ -10,10 +10,7 @@ class AmazonBedrockError(Exception):
     `AmazonBedrockError.message` will exist and have the expected content.
     """
 
-    def __init__(
-        self,
-        message: Optional[str] = None,
-    ):
+    def __init__(self, message: Optional[str] = None):
         super().__init__()
         if message:
             self.message = message
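A quick, hedged illustration of the behavior described in the docstring above (not part of this commit): when a message is passed, the simplified constructor still sets .message on the exception.

error = AmazonBedrockError("Could not run inference on the Bedrock model")
print(error.message)  # "Could not run inference on the Bedrock model"
print(isinstance(error, Exception))  # True; it can be raised and caught as usual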
@@ -15,16 +15,8 @@
     CohereCommandAdapter,
     MetaLlama2ChatAdapter,
 )
-from .errors import (
-    AmazonBedrockConfigurationError,
-    AmazonBedrockInferenceError,
-    AWSConfigurationError,
-)
-from .handlers import (
-    DefaultPromptHandler,
-    DefaultTokenStreamingHandler,
-    TokenStreamingHandler,
-)
+from .errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError
+from .handlers import DefaultPromptHandler, DefaultTokenStreamingHandler, TokenStreamingHandler
 
 logger = logging.getLogger(__name__)
 
@@ -112,9 +104,7 @@ def __init__(
         # It is hard to determine which tokenizer to use for the SageMaker model
         # so we use GPT2 tokenizer which will likely provide good token count approximation
         self.prompt_handler = DefaultPromptHandler(
-            model="gpt2",
-            model_max_length=model_max_length,
-            max_length=self.max_length or 100,
+            model="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
         )
 
         model_adapter_cls = self.get_model_adapter(model=model)
@@ -203,10 +193,7 @@ def invoke(self, *args, **kwargs):
         try:
             if stream:
                 response = self.client.invoke_model_with_response_stream(
-                    body=json.dumps(body),
-                    modelId=self.model,
-                    accept="application/json",
-                    contentType="application/json",
+                    body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
                 )
                 response_stream = response["body"]
                 handler: TokenStreamingHandler = kwargs.get(
@@ -216,10 +203,7 @@
                 responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler)
             else:
                 response = self.client.invoke_model(
-                    body=json.dumps(body),
-                    modelId=self.model,
-                    accept="application/json",
-                    contentType="application/json",
+                    body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
                 )
                 response_body = json.loads(response.get("body").read().decode("utf-8"))
                 responses = self.model_adapter.get_responses(response_body=response_body)
@@ -296,11 +280,7 @@ def to_dict(self) -> Dict[str, Any]:
         Serialize this component to a dictionary.
         :return: The serialized component as a dictionary.
         """
-        return default_to_dict(
-            self,
-            model=self.model,
-            max_length=self.max_length,
-        )
+        return default_to_dict(self, model=self.model, max_length=self.max_length)
 
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockGenerator":
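For context, a minimal sketch of constructing and serializing the reformatted generator, based on the to_dict/from_dict expectations in the tests below. Instantiating it for real requires AWS credentials available to boto3 (the tests mock the session) and downloads the gpt2 tokenizer used by the prompt handler:

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

generator = AmazonBedrockGenerator("anthropic.claude-v2", max_length=99)
data = generator.to_dict()
# data["init_parameters"] == {"model": "anthropic.claude-v2", "max_length": 99}
restored = AmazonBedrockGenerator.from_dict(data)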
127 changes: 23 additions & 104 deletions integrations/amazon_bedrock/tests/test_amazon_bedrock.py
@@ -56,10 +56,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session):
 
     expected_dict = {
         "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
-        "init_parameters": {
-            "model": "anthropic.claude-v2",
-            "max_length": 99,
-        },
+        "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99},
     }
 
     assert generator.to_dict() == expected_dict
@@ -73,10 +70,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
     generator = AmazonBedrockGenerator.from_dict(
         {
             "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator",
-            "init_parameters": {
-                "model": "anthropic.claude-v2",
-                "max_length": 99,
-            },
+            "init_parameters": {"model": "anthropic.claude-v2", "max_length": 99},
         }
     )
 
@@ -181,9 +175,7 @@ def test_short_prompt_is_not_truncated(mock_boto3_session):
 
     with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
         layer = AmazonBedrockGenerator(
-            "anthropic.claude-v2",
-            max_length=max_length_generated_text,
-            model_max_length=total_model_max_length,
+            "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length
         )
         prompt_after_resize = layer._ensure_token_limit(mock_prompt_text)
 
@@ -216,9 +208,7 @@ def test_long_prompt_is_truncated(mock_boto3_session):
 
     with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
         layer = AmazonBedrockGenerator(
-            "anthropic.claude-v2",
-            max_length=max_length_generated_text,
-            model_max_length=total_model_max_length,
+            "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length
         )
         prompt_after_resize = layer._ensure_token_limit(long_prompt_text)
 
@@ -238,10 +228,7 @@ def test_supports_for_valid_aws_configuration():
         "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
         return_value=mock_session,
     ):
-        supported = AmazonBedrockGenerator.supports(
-            model="anthropic.claude-v2",
-            aws_profile_name="some_real_profile",
-        )
+        supported = AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile")
         args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args
         assert kwargs["byOutputModality"] == "TEXT"
 
@@ -253,10 +240,7 @@ def test_supports_raises_on_invalid_aws_profile_name():
     with patch("boto3.Session") as mock_boto3_session:
         mock_boto3_session.side_effect = BotoCoreError()
         with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"):
-            AmazonBedrockGenerator.supports(
-                model="anthropic.claude-v2",
-                aws_profile_name="some_fake_profile",
-            )
+            AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_fake_profile")
 
 
 @pytest.mark.unit
@@ -269,10 +253,7 @@ def test_supports_for_invalid_bedrock_config():
         "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
         return_value=mock_session,
     ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
-        AmazonBedrockGenerator.supports(
-            model="anthropic.claude-v2",
-            aws_profile_name="some_real_profile",
-        )
+        AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile")
 
 
 @pytest.mark.unit
@@ -285,10 +266,7 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models():
         "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
         return_value=mock_session,
     ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."):
-        AmazonBedrockGenerator.supports(
-            model="anthropic.claude-v2",
-            aws_profile_name="some_real_profile",
-        )
+        AmazonBedrockGenerator.supports(model="anthropic.claude-v2", aws_profile_name="some_real_profile")
 
 
 @pytest.mark.unit
@@ -318,9 +296,7 @@ def test_supports_with_stream_true_for_model_that_supports_streaming():
         return_value=mock_session,
     ):
         supported = AmazonBedrockGenerator.supports(
-            model="anthropic.claude-v2",
-            aws_profile_name="some_real_profile",
-            stream=True,
+            model="anthropic.claude-v2", aws_profile_name="some_real_profile", stream=True
         )
 
     assert supported
@@ -337,15 +313,8 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming():
     with patch(
         "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session",
         return_value=mock_session,
-    ), pytest.raises(
-        AmazonBedrockConfigurationError,
-        match="The model ai21.j2-mid-v1 doesn't support streaming.",
-    ):
-        AmazonBedrockGenerator.supports(
-            model="ai21.j2-mid-v1",
-            aws_profile_name="some_real_profile",
-            stream=True,
-        )
+    ), pytest.raises(AmazonBedrockConfigurationError, match="The model ai21.j2-mid-v1 doesn't support streaming."):
+        AmazonBedrockGenerator.supports(model="ai21.j2-mid-v1", aws_profile_name="some_real_profile", stream=True)
 
 
 @pytest.mark.unit
@@ -665,15 +634,9 @@ def test_get_responses_leading_whitespace(self) -> None:
     def test_get_responses_multiple_responses(self) -> None:
         adapter = CohereCommandAdapter(model_kwargs={}, max_length=99)
         response_body = {
-            "generations": [
-                {"text": "This is a single response."},
-                {"text": "This is a second response."},
-            ]
+            "generations": [{"text": "This is a single response."}, {"text": "This is a second response."}]
         }
-        expected_responses = [
-            "This is a single response.",
-            "This is a second response.",
-        ]
+        expected_responses = ["This is a single response.", "This is a second response."]
         assert adapter.get_responses(response_body) == expected_responses
 
     def test_get_stream_responses(self) -> None:
@@ -854,21 +817,15 @@ def test_get_responses_multiple_responses(self) -> None:
                 {"data": {"text": "This is a second response."}},
             ]
         }
-        expected_responses = [
-            "This is a single response.",
-            "This is a second response.",
-        ]
+        expected_responses = ["This is a single response.", "This is a second response."]
         assert adapter.get_responses(response_body) == expected_responses
 
 
 class TestAmazonTitanAdapter:
     def test_prepare_body_with_default_params(self) -> None:
         layer = AmazonTitanAdapter(model_kwargs={}, max_length=99)
         prompt = "Hello, how are you?"
-        expected_body = {
-            "inputText": "Hello, how are you?",
-            "textGenerationConfig": {"maxTokenCount": 99},
-        }
+        expected_body = {"inputText": "Hello, how are you?", "textGenerationConfig": {"maxTokenCount": 99}}
 
         body = layer.prepare_body(prompt)
 
@@ -964,15 +921,9 @@ def test_get_responses_leading_whitespace(self) -> None:
     def test_get_responses_multiple_responses(self) -> None:
         adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99)
         response_body = {
-            "results": [
-                {"outputText": "This is a single response."},
-                {"outputText": "This is a second response."},
-            ]
+            "results": [{"outputText": "This is a single response."}, {"outputText": "This is a second response."}]
         }
-        expected_responses = [
-            "This is a single response.",
-            "This is a second response.",
-        ]
+        expected_responses = ["This is a single response.", "This is a second response."]
         assert adapter.get_responses(response_body) == expected_responses
 
     def test_get_stream_responses(self) -> None:
@@ -1031,62 +982,30 @@ def test_prepare_body_with_default_params(self) -> None:
     def test_prepare_body_with_custom_inference_params(self) -> None:
         layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99)
         prompt = "Hello, how are you?"
-        expected_body = {
-            "prompt": "Hello, how are you?",
-            "max_gen_len": 50,
-            "temperature": 0.7,
-            "top_p": 0.8,
-        }
+        expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8}
 
-        body = layer.prepare_body(
-            prompt,
-            temperature=0.7,
-            top_p=0.8,
-            max_gen_len=50,
-            unknown_arg="unknown_value",
-        )
+        body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, max_gen_len=50, unknown_arg="unknown_value")
 
         assert body == expected_body
 
     def test_prepare_body_with_model_kwargs(self) -> None:
         layer = MetaLlama2ChatAdapter(
-            model_kwargs={
-                "temperature": 0.7,
-                "top_p": 0.8,
-                "max_gen_len": 50,
-                "unknown_arg": "unknown_value",
-            },
+            model_kwargs={"temperature": 0.7, "top_p": 0.8, "max_gen_len": 50, "unknown_arg": "unknown_value"},
             max_length=99,
         )
         prompt = "Hello, how are you?"
-        expected_body = {
-            "prompt": "Hello, how are you?",
-            "max_gen_len": 50,
-            "temperature": 0.7,
-            "top_p": 0.8,
-        }
+        expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8}
 
         body = layer.prepare_body(prompt)
 
         assert body == expected_body
 
     def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None:
         layer = MetaLlama2ChatAdapter(
-            model_kwargs={
-                "temperature": 0.6,
-                "top_p": 0.7,
-                "top_k": 4,
-                "max_gen_len": 49,
-            },
-            max_length=99,
+            model_kwargs={"temperature": 0.6, "top_p": 0.7, "top_k": 4, "max_gen_len": 49}, max_length=99
         )
         prompt = "Hello, how are you?"
-        expected_body = {
-            "prompt": "Hello, how are you?",
-            "max_gen_len": 50,
-            "temperature": 0.7,
-            "top_p": 0.7,
-        }
+        expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.7}
 
         body = layer.prepare_body(prompt, temperature=0.7, max_gen_len=50)
 
12 changes: 2 additions & 10 deletions integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py
@@ -194,28 +194,20 @@ def test_prepare_body_with_custom_inference_params(self, mock_auto_tokenizer) ->
         }
 
         body = layer.prepare_body(
-            [ChatMessage.from_user(prompt)],
-            top_p=0.8,
-            top_k=5,
-            max_tokens_to_sample=69,
-            stop_sequences=["CUSTOM_STOP"],
+            [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"]
         )
 
         assert body == expected_body
 
 
 class TestMetaLlama2ChatAdapter:
-
     @pytest.mark.integration
     def test_prepare_body_with_default_params(self) -> None:
         # leave this test as integration because we really need only tokenizer from HF
         # that way we can ensure prompt chat message formatting
         layer = MetaLlama2ChatAdapter(generation_kwargs={})
         prompt = "Hello, how are you?"
-        expected_body = {
-            "prompt": "<s>[INST] Hello, how are you? [/INST]",
-            "max_gen_len": 512,
-        }
+        expected_body = {"prompt": "<s>[INST] Hello, how are you? [/INST]", "max_gen_len": 512}
 
         body = layer.prepare_body([ChatMessage.from_user(prompt)])
 
