diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py index 21dd5b0bc..aa8a3f6e4 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py +++ b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py @@ -11,7 +11,8 @@ class AmazonBedrockError(Exception): """ def __init__( - self, message: Optional[str] = None, + self, + message: Optional[str] = None, ): super().__init__() if message: diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py index b61751506..fc96b1b03 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py +++ b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py @@ -70,15 +70,15 @@ class AmazonBedrockGenerator: } def __init__( - self, - model_name_or_path: str, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, - max_length: Optional[int] = 100, - **kwargs, + self, + model_name_or_path: str, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + max_length: Optional[int] = 100, + **kwargs, ): if model_name_or_path is None or len(model_name_or_path) == 0: msg = "model_name_or_path cannot be None or empty string" @@ -96,8 +96,10 @@ def __init__( ) self.client = session.client("bedrock-runtime") except Exception as exception: - msg = ("Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " - "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration") + msg = ( + "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) raise AmazonBedrockConfigurationError(msg) from exception model_input_kwargs = kwargs @@ -115,23 +117,19 @@ def __init__( max_length=self.max_length or 100, ) - model_apapter_cls = self.get_model_adapter( - model_name_or_path=model_name_or_path - ) + model_apapter_cls = self.get_model_adapter(model_name_or_path=model_name_or_path) if not model_apapter_cls: msg = f"AmazonBedrockGenerator doesn't support the model {model_name_or_path}." raise AmazonBedrockConfigurationError(msg) - self.model_adapter = model_apapter_cls( - model_kwargs=model_input_kwargs, max_length=self.max_length - ) + self.model_adapter = model_apapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length) - def _ensure_token_limit( - self, prompt: Union[str, List[Dict[str, str]]] - ) -> Union[str, List[Dict[str, str]]]: + def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]: # the prompt for this model will be of the type str if isinstance(prompt, List): - msg = ("AmazonBedrockGenerator only supports a string as a prompt, " - "while currently, the prompt is of type List.") + msg = ( + "AmazonBedrockGenerator only supports a string as a prompt, " + "while currently, the prompt is of type List." 
+ ) raise ValueError(msg) resize_info = self.prompt_handler(prompt) @@ -156,32 +154,29 @@ def supports(cls, model_name_or_path, **kwargs): try: session = cls.get_aws_session(**kwargs) bedrock = session.client("bedrock") - foundation_models_response = bedrock.list_foundation_models( - byOutputModality="TEXT" - ) - available_model_ids = [ - entry["modelId"] - for entry in foundation_models_response.get("modelSummaries", []) - ] + foundation_models_response = bedrock.list_foundation_models(byOutputModality="TEXT") + available_model_ids = [entry["modelId"] for entry in foundation_models_response.get("modelSummaries", [])] model_ids_supporting_streaming = [ entry["modelId"] for entry in foundation_models_response.get("modelSummaries", []) if entry.get("responseStreamingSupported", False) ] except AWSConfigurationError as exception: - raise AmazonBedrockConfigurationError( - message=exception.message - ) from exception + raise AmazonBedrockConfigurationError(message=exception.message) from exception except Exception as exception: - msg = ("Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " - "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration") + msg = ( + "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) raise AmazonBedrockConfigurationError(msg) from exception model_available = model_name_or_path in available_model_ids if not model_available: - msg = (f"The model {model_name_or_path} is not available in Amazon Bedrock. " - f"Make sure the model you want to use is available in the configured AWS region and " - f"you have access.") + msg = ( + f"The model {model_name_or_path} is not available in Amazon Bedrock. " + f"Make sure the model you want to use is available in the configured AWS region and " + f"you have access." + ) raise AmazonBedrockConfigurationError(msg) stream: bool = kwargs.get("stream", False) @@ -195,13 +190,13 @@ def supports(cls, model_name_or_path, **kwargs): def invoke(self, *args, **kwargs): kwargs = kwargs.copy() prompt: str = kwargs.pop("prompt", None) - stream: bool = kwargs.get( - "stream", self.model_adapter.model_kwargs.get("stream", False) - ) + stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False)) if not prompt or not isinstance(prompt, (str, list)): - msg = (f"The model {self.model_name_or_path} requires a valid prompt, but currently, it has no prompt. " - f"Make sure to provide a prompt in the format that the model expects.") + msg = ( + f"The model {self.model_name_or_path} requires a valid prompt, but currently, it has no prompt. " + f"Make sure to provide a prompt in the format that the model expects." 
+ ) raise ValueError(msg) body = self.model_adapter.prepare_body(prompt=prompt, **kwargs) @@ -216,13 +211,9 @@ def invoke(self, *args, **kwargs): response_stream = response["body"] handler: TokenStreamingHandler = kwargs.get( "stream_handler", - self.model_adapter.model_kwargs.get( - "stream_handler", DefaultTokenStreamingHandler() - ), - ) - responses = self.model_adapter.get_stream_responses( - stream=response_stream, stream_handler=handler + self.model_adapter.model_kwargs.get("stream_handler", DefaultTokenStreamingHandler()), ) + responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler) else: response = self.client.invoke_model( body=json.dumps(body), @@ -231,13 +222,13 @@ def invoke(self, *args, **kwargs): contentType="application/json", ) response_body = json.loads(response.get("body").read().decode("utf-8")) - responses = self.model_adapter.get_responses( - response_body=response_body - ) + responses = self.model_adapter.get_responses(response_body=response_body) except ClientError as exception: - msg = (f"Could not connect to Amazon Bedrock model {self.model_name_or_path}. " - f"Make sure your AWS environment is configured correctly, " - f"the model is available in the configured AWS region, and you have access.") + msg = ( + f"Could not connect to Amazon Bedrock model {self.model_name_or_path}. " + f"Make sure your AWS environment is configured correctly, " + f"the model is available in the configured AWS region, and you have access." + ) raise AmazonBedrockInferenceError(msg) from exception return responses @@ -247,9 +238,7 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): pass @classmethod - def get_model_adapter( - cls, model_name_or_path: str - ) -> Optional[Type[BedrockModelAdapter]]: + def get_model_adapter(cls, model_name_or_path: str) -> Optional[Type[BedrockModelAdapter]]: for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): if re.fullmatch(pattern, model_name_or_path): return adapter @@ -267,13 +256,13 @@ def aws_configured(cls, **kwargs) -> bool: @classmethod def get_aws_session( - cls, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, - **kwargs, + cls, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + **kwargs, ): """ Creates an AWS Session with the given parameters. 
@@ -298,8 +287,6 @@ def get_aws_session( profile_name=aws_profile_name, ) except BotoCoreError as e: - provided_aws_config = { - k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS - } + provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" raise AWSConfigurationError(msg) from e diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py index 54ebda86c..bec172867 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py +++ b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py @@ -24,9 +24,7 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[str]: responses = [completion.lstrip() for completion in completions] return responses - def get_stream_responses( - self, stream, stream_handler: TokenStreamingHandler - ) -> List[str]: + def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]: tokens: List[str] = [] for event in stream: chunk = event.get("chunk") @@ -37,9 +35,7 @@ def get_stream_responses( responses = ["".join(tokens).lstrip()] return responses - def _get_params( - self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any] - ) -> Dict[str, Any]: + def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: """ Merges the default params with the inference kwargs and model kwargs. @@ -54,9 +50,7 @@ def _get_params( } @abstractmethod - def _extract_completions_from_response( - self, response_body: Dict[str, Any] - ) -> List[str]: + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: """Extracts the responses from the Amazon Bedrock response.""" @abstractmethod @@ -82,9 +76,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params} return body - def _extract_completions_from_response( - self, response_body: Dict[str, Any] - ) -> List[str]: + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: return [response_body["completion"]] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: @@ -114,9 +106,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: body = {"prompt": prompt, **params} return body - def _extract_completions_from_response( - self, response_body: Dict[str, Any] - ) -> List[str]: + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: responses = [generation["text"] for generation in response_body["generations"]] return responses @@ -145,12 +135,8 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: body = {"prompt": prompt, **params} return body - def _extract_completions_from_response( - self, response_body: Dict[str, Any] - ) -> List[str]: - responses = [ - completion["data"]["text"] for completion in response_body["completions"] - ] + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + responses = [completion["data"]["text"] for completion in response_body["completions"]] return responses def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: @@ -175,9 +161,7 @@ def prepare_body(self, 
prompt: str, **inference_kwargs) -> Dict[str, Any]: body = {"inputText": prompt, "textGenerationConfig": params} return body - def _extract_completions_from_response( - self, response_body: Dict[str, Any] - ) -> List[str]: + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: responses = [result["outputText"] for result in response_body["results"]] return responses @@ -201,9 +185,7 @@ def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: body = {"prompt": prompt, **params} return body - def _extract_completions_from_response( - self, response_body: Dict[str, Any] - ) -> List[str]: + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: return [response_body["generation"]] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py index 2532c13eb..26bc966b8 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py +++ b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py @@ -10,9 +10,7 @@ class DefaultPromptHandler: are within the model_max_length. """ - def __init__( - self, model_name_or_path: str, model_max_length: int, max_length: int = 100 - ): + def __init__(self, model_name_or_path: str, model_max_length: int, max_length: int = 100): self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.tokenizer.model_max_length = model_max_length self.model_max_length = model_max_length @@ -40,9 +38,7 @@ def __call__(self, prompt: str, **kwargs) -> Dict[str, Union[str, int]]: resized_prompt = self.tokenizer.convert_tokens_to_string( tokenized_prompt[: self.model_max_length - self.max_length] ) - new_prompt_length = len( - tokenized_prompt[: self.model_max_length - self.max_length] - ) + new_prompt_length = len(tokenized_prompt[: self.model_max_length - self.max_length]) return { "resized_prompt": resized_prompt, diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index 26fa24e5c..cecb28578 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -26,7 +26,7 @@ def mock_boto3_session(): @pytest.fixture def mock_prompt_handler(): with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock_handlers.DefaultPromptHandler" + "amazon_bedrock_haystack.generators.amazon_bedrock_handlers.DefaultPromptHandler" ) as mock_prompt_handler: yield mock_prompt_handler @@ -67,15 +67,11 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): @pytest.mark.unit -def test_constructor_prompt_handler_initialized( - mock_auto_tokenizer, mock_boto3_session -): +def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session): """ Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2 """ - layer = AmazonBedrockGenerator( - model_name_or_path="anthropic.claude-v2", prompt_handler=mock_prompt_handler - ) + layer = AmazonBedrockGenerator(model_name_or_path="anthropic.claude-v2", prompt_handler=mock_prompt_handler) assert layer.prompt_handler is not None assert layer.prompt_handler.model_max_length == 4096 @@ -87,9 +83,7 @@ def 
test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session): """ model_kwargs = {"temperature": 0.7} - layer = AmazonBedrockGenerator( - model_name_or_path="anthropic.claude-v2", **model_kwargs - ) + layer = AmazonBedrockGenerator(model_name_or_path="anthropic.claude-v2", **model_kwargs) assert "temperature" in layer.model_adapter.model_kwargs assert layer.model_adapter.model_kwargs["temperature"] == 0.7 @@ -109,9 +103,7 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): Test invoke raises an error if no prompt is provided """ layer = AmazonBedrockGenerator(model_name_or_path="anthropic.claude-v2") - with pytest.raises( - ValueError, match="The model anthropic.claude-v2 requires a valid prompt." - ): + with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires a valid prompt."): layer.invoke() @@ -134,9 +126,7 @@ def test_short_prompt_is_not_truncated(mock_boto3_session): max_length_generated_text = 3 total_model_max_length = 10 - with patch( - "transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer - ): + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( "anthropic.claude-v2", max_length=max_length_generated_text, @@ -171,9 +161,7 @@ def test_long_prompt_is_truncated(mock_boto3_session): max_length_generated_text = 3 total_model_max_length = 10 - with patch( - "transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer - ): + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): layer = AmazonBedrockGenerator( "anthropic.claude-v2", max_length=max_length_generated_text, @@ -194,8 +182,8 @@ def test_supports_for_valid_aws_configuration(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, + "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", + return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( model_name_or_path="anthropic.claude-v2", @@ -211,9 +199,7 @@ def test_supports_for_valid_aws_configuration(): def test_supports_raises_on_invalid_aws_profile_name(): with patch("boto3.Session") as mock_boto3_session: mock_boto3_session.side_effect = BotoCoreError() - with pytest.raises( - AmazonBedrockConfigurationError, match="Failed to initialize the session" - ): + with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"): AmazonBedrockGenerator.supports( model_name_or_path="anthropic.claude-v2", aws_profile_name="some_fake_profile", @@ -227,11 +213,9 @@ def test_supports_for_invalid_bedrock_config(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, - ), pytest.raises( - AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock." 
- ): + "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", + return_value=mock_session, + ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): AmazonBedrockGenerator.supports( model_name_or_path="anthropic.claude-v2", aws_profile_name="some_real_profile", @@ -245,11 +229,9 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, - ), pytest.raises( - AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock." - ): + "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", + return_value=mock_session, + ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): AmazonBedrockGenerator.supports( model_name_or_path="anthropic.claude-v2", aws_profile_name="some_real_profile", @@ -258,9 +240,7 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models(): @pytest.mark.unit def test_supports_for_no_aws_params(): - supported = AmazonBedrockGenerator.supports( - model_name_or_path="anthropic.claude-v2" - ) + supported = AmazonBedrockGenerator.supports(model_name_or_path="anthropic.claude-v2") assert supported is False @@ -278,15 +258,13 @@ def test_supports_for_unknown_model(): def test_supports_with_stream_true_for_model_that_supports_streaming(): mock_session = MagicMock() mock_session.client("bedrock").list_foundation_models.return_value = { - "modelSummaries": [ - {"modelId": "anthropic.claude-v2", "responseStreamingSupported": True} - ] + "modelSummaries": [{"modelId": "anthropic.claude-v2", "responseStreamingSupported": True}] } # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, + "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", + return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( model_name_or_path="anthropic.claude-v2", @@ -301,15 +279,13 @@ def test_supports_with_stream_true_for_model_that_supports_streaming(): def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): mock_session = MagicMock() mock_session.client("bedrock").list_foundation_models.return_value = { - "modelSummaries": [ - {"modelId": "ai21.j2-mid-v1", "responseStreamingSupported": False} - ] + "modelSummaries": [{"modelId": "ai21.j2-mid-v1", "responseStreamingSupported": False}] } # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", - return_value=mock_session, + "amazon_bedrock_haystack.generators.AmazonBedrockGenerator.get_aws_session", + return_value=mock_session, ), pytest.raises( AmazonBedrockConfigurationError, match="The model ai21.j2-mid-v1 doesn't support streaming.", @@ -345,15 +321,11 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): ("unknown_model", None), ], ) -def test_get_model_adapter( - model_name_or_path: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]] -): +def test_get_model_adapter(model_name_or_path: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): """ Test that the correct model adapter is returned for a given model_name_or_path """ - model_adapter = AmazonBedrockGenerator.get_model_adapter( - 
model_name_or_path=model_name_or_path - ) + model_adapter = AmazonBedrockGenerator.get_model_adapter(model_name_or_path=model_name_or_path) assert model_adapter == expected_model_adapter @@ -442,9 +414,7 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non "top_k": 5, } - body = layer.prepare_body( - prompt, temperature=0.7, top_p=0.8, top_k=5, max_tokens_to_sample=50 - ) + body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, top_k=5, max_tokens_to_sample=50) assert body == expected_body @@ -472,16 +442,11 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"completion": " response."}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -499,16 +464,11 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() @@ -675,23 +635,14 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"text": " a"}'}}, {"chunk": {"bytes": b'{"text": " single"}'}}, {"chunk": {"bytes": b'{"text": " response."}'}}, - { - "chunk": { - "bytes": b'{"finish_reason": "MAX_TOKENS", "is_finished": true}' - } - }, + {"chunk": {"bytes": b'{"finish_reason": "MAX_TOKENS", "is_finished": true}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -700,9 +651,7 @@ def test_get_stream_responses(self) -> None: call(" a", event_data={"text": " a"}), call(" single", event_data={"text": " single"}), call(" response.", event_data={"text": " response."}), - call( - "", event_data={"finish_reason": "MAX_TOKENS", "is_finished": True} - ), + call("", event_data={"finish_reason": "MAX_TOKENS", "is_finished": True}), ] ) @@ -712,16 +661,11 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == 
expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() @@ -841,17 +785,13 @@ def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> Non def test_get_responses(self) -> None: adapter = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) - response_body = { - "completions": [{"data": {"text": "This is a single response."}}] - } + response_body = {"completions": [{"data": {"text": "This is a single response."}}]} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses def test_get_responses_leading_whitespace(self) -> None: adapter = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) - response_body = { - "completions": [{"data": {"text": "\n\t This is a single response."}}] - } + response_body = {"completions": [{"data": {"text": "\n\t This is a single response."}}]} expected_responses = ["This is a single response."] assert adapter.get_responses(response_body) == expected_responses @@ -996,16 +936,11 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"outputText": " response."}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -1023,16 +958,11 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called() @@ -1135,16 +1065,11 @@ def test_get_stream_responses(self) -> None: {"chunk": {"bytes": b'{"generation": " response."}'}}, ] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) expected_responses = ["This is a single response."] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses stream_handler_mock.assert_has_calls( [ @@ -1162,15 +1087,10 @@ def test_get_stream_responses_empty(self) -> None: stream_mock.__iter__.return_value = [] - stream_handler_mock.side_effect = ( - lambda token_received, **kwargs: token_received - ) + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) expected_responses = [""] - assert ( - adapter.get_stream_responses(stream_mock, stream_handler_mock) - == expected_responses - ) + assert adapter.get_stream_responses(stream_mock, 
stream_handler_mock) == expected_responses stream_handler_mock.assert_not_called()
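
For context, a minimal usage sketch of the generator API touched by this diff. The import path is inferred from the module layout above, and the model ID is illustrative, not part of the change:

```python
from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator

# Assumed setup: AWS credentials/region already configured, e.g. via a profile, as
# described in the boto3 quickstart guide linked in the error messages above.
generator = AmazonBedrockGenerator(
    model_name_or_path="anthropic.claude-v2",  # any model matching SUPPORTED_MODEL_PATTERNS
    max_length=100,  # generation budget forwarded to the model adapter
)

# invoke() pops "prompt" from kwargs, builds the model-specific request body via
# the matched adapter, calls the bedrock-runtime client, and returns completions.
responses = generator.invoke(prompt="What is Amazon Bedrock?")
print(responses[0])
```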
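The streaming branch resolves a `TokenStreamingHandler` from `kwargs` or `model_kwargs`, falling back to `DefaultTokenStreamingHandler`, as shown in `invoke()` above. A hedged sketch of passing a custom handler — `PrintingHandler` is hypothetical, but its call signature mirrors the `stream_handler_mock` used in the tests:

```python
from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator

class PrintingHandler:
    """Hypothetical stand-in for a TokenStreamingHandler subclass: called once
    per streamed token and expected to return the (possibly transformed) token."""

    def __call__(self, token_received: str, **kwargs) -> str:
        print(token_received, end="", flush=True)
        return token_received

generator = AmazonBedrockGenerator(model_name_or_path="anthropic.claude-v2")
responses = generator.invoke(
    prompt="Summarize Amazon Bedrock in one sentence.",
    stream=True,                       # selects the streaming branch of invoke()
    stream_handler=PrintingHandler(),  # overrides DefaultTokenStreamingHandler
)
```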
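The prompt-handler hunks also show the truncation contract exercised by `test_long_prompt_is_truncated`. A sketch of calling `DefaultPromptHandler` directly, using the same import path the tests patch; the tokenizer checkpoint is an illustrative assumption:

```python
from amazon_bedrock_haystack.generators.amazon_bedrock_handlers import DefaultPromptHandler

# model_max_length bounds prompt + completion tokens, so the prompt is resized
# to at most (model_max_length - max_length) tokens, mirroring the truncation tests.
handler = DefaultPromptHandler(
    model_name_or_path="gpt2",  # illustrative tokenizer checkpoint, an assumption
    model_max_length=10,
    max_length=3,
)
resize_info = handler("This is a very long prompt that should be truncated")
print(resize_info["resized_prompt"])
```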