Skip to content

Commit

Permalink
fix tests after llama adapter rename
Browse files Browse the repository at this point in the history
  • Loading branch information
tstadel committed Jun 14, 2024
1 parent 0e9f196 commit c16f187
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions integrations/amazon_bedrock/tests/test_generator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,7 +11,7 @@
BedrockModelAdapter,
CohereCommandAdapter,
CohereCommandRAdapter,
MetaLlama2ChatAdapter,
MetaLlamaAdapter,
MistralAdapter,
)

Expand Down Expand Up @@ -214,12 +214,12 @@ def test_long_prompt_is_truncated(mock_boto3_session):
("amazon.titan-text-express-v1", AmazonTitanAdapter),
("amazon.titan-text-agile-v1", AmazonTitanAdapter),
("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial
("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter),
("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter),
("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial
("meta.llama3-8b-instruct-v1:0", MetaLlama2ChatAdapter),
("meta.llama3-70b-instruct-v1:0", MetaLlama2ChatAdapter),
("meta.llama3-130b-instruct-v5:9", MetaLlama2ChatAdapter), # artificial
("meta.llama2-13b-chat-v1", MetaLlamaAdapter),
("meta.llama2-70b-chat-v1", MetaLlamaAdapter),
("meta.llama2-130b-v5", MetaLlamaAdapter), # artificial
("meta.llama3-8b-instruct-v1:0", MetaLlamaAdapter),
("meta.llama3-70b-instruct-v1:0", MetaLlamaAdapter),
("meta.llama3-130b-instruct-v5:9", MetaLlamaAdapter), # artificial
("mistral.mistral-7b-instruct-v0:2", MistralAdapter),
("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter),
("mistral.mistral-large-2402-v1:0", MistralAdapter),
Expand Down Expand Up @@ -1290,9 +1290,9 @@ def test_get_stream_responses_empty(self) -> None:
stream_handler_mock.assert_not_called()


class TestMetaLlama2ChatAdapter:
class TestMetaLlamaAdapter:
def test_prepare_body_with_default_params(self) -> None:
layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99)
layer = MetaLlamaAdapter(model_kwargs={}, max_length=99)
prompt = "Hello, how are you?"
expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 99}

Expand All @@ -1301,7 +1301,7 @@ def test_prepare_body_with_default_params(self) -> None:
assert body == expected_body

def test_prepare_body_with_custom_inference_params(self) -> None:
layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99)
layer = MetaLlamaAdapter(model_kwargs={}, max_length=99)
prompt = "Hello, how are you?"
expected_body = {
"prompt": "Hello, how are you?",
Expand All @@ -1321,7 +1321,7 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
assert body == expected_body

def test_prepare_body_with_model_kwargs(self) -> None:
layer = MetaLlama2ChatAdapter(
layer = MetaLlamaAdapter(
model_kwargs={
"temperature": 0.7,
"top_p": 0.8,
Expand All @@ -1343,7 +1343,7 @@ def test_prepare_body_with_model_kwargs(self) -> None:
assert body == expected_body

def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None:
layer = MetaLlama2ChatAdapter(
layer = MetaLlamaAdapter(
model_kwargs={
"temperature": 0.6,
"top_p": 0.7,
Expand All @@ -1365,13 +1365,13 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non
assert body == expected_body

def test_get_responses(self) -> None:
adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99)
adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99)
response_body = {"generation": "This is a single response."}
expected_responses = ["This is a single response."]
assert adapter.get_responses(response_body) == expected_responses

def test_get_responses_leading_whitespace(self) -> None:
adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99)
adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99)
response_body = {"generation": "\n\t This is a single response."}
expected_responses = ["This is a single response."]
assert adapter.get_responses(response_body) == expected_responses
Expand All @@ -1390,7 +1390,7 @@ def test_get_stream_responses(self) -> None:

stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received

adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99)
adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99)
expected_responses = ["This is a single response."]
assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses

Expand All @@ -1412,7 +1412,7 @@ def test_get_stream_responses_empty(self) -> None:

stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received

adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99)
adapter = MetaLlamaAdapter(model_kwargs={}, max_length=99)
expected_responses = [""]
assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses

Expand Down

0 comments on commit c16f187

Please sign in to comment.