From 11b9b2f0a608cdeb61efaa30e32058f50f3ab0fa Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 7 Feb 2024 22:43:52 +0100
Subject: [PATCH] PR feedback David

---
 .../amazon_bedrock/chat/adapters.py           | 30 ++++++++++++-------
 .../amazon_bedrock/chat/chat_generator.py     |  4 ++-
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 37c490f6b..a4eefe321 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -3,6 +3,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List
 
+from botocore.eventstream import EventStream
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
@@ -25,9 +26,9 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[
 
     def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """Extracts the responses from the Amazon Bedrock response."""
-        return self._extract_messages_from_response(response_body)
+        return self._extract_messages_from_response(self.response_body_message_key(), response_body)
 
-    def get_stream_responses(self, stream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
+    def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
         tokens: List[str] = []
         for event in stream:
             chunk = event.get("chunk")
@@ -43,7 +44,8 @@ def get_stream_responses(self, stream, stream_handler: Callable[[StreamingChunk]
         responses = ["".join(tokens).lstrip()]
         return responses
 
-    def _update_params(self, target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None:
+    @staticmethod
+    def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None:
         """
         Updates target_dict with values from updates_dict.
         Merges lists instead of overriding them.
@@ -62,6 +64,10 @@ def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str
         """
         Merges params from inference_kwargs with the default params and self.generation_kwargs.
         Uses a helper function to merge lists or override values as necessary.
+
+        :param inference_kwargs: The inference kwargs to merge.
+        :param default_params: The default params to start with.
+        :return: The merged params.
         """
         # Start with a copy of default_params
         kwargs = default_params.copy()
@@ -95,9 +101,13 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         :return: A dictionary containing the resized prompt and additional information.
         """
 
+    def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
+        metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
+        return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
+
     @abstractmethod
-    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
-        """Extracts the responses from the Amazon Bedrock response."""
+    def response_body_message_key(self) -> str:
+        """Returns the key for the message in the response body."""
 
     @abstractmethod
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
@@ -171,9 +181,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
     def check_prompt(self, prompt: str) -> Dict[str, Any]:
         return self.prompt_handler(prompt)
 
-    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
-        metadata = {k: v for (k, v) in response_body.items() if k != "completion"}
-        return [ChatMessage.from_assistant(response_body["completion"], meta=metadata)]
+    def response_body_message_key(self) -> str:
+        return "completion"
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         return chunk.get("completion", "")
@@ -250,9 +259,8 @@ def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
     def check_prompt(self, prompt: str) -> Dict[str, Any]:
         return self.prompt_handler(prompt)
 
-    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
-        metadata = {k: v for (k, v) in response_body.items() if k != "generation"}
-        return [ChatMessage.from_assistant(response_body["generation"], meta=metadata)]
+    def response_body_message_key(self) -> str:
+        return "generation"
 
     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         return chunk.get("generation", "")
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
index fda9d4fff..ecb0c7bb9 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
@@ -76,7 +76,9 @@ def __init__(
         Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded
         automatically from the environment or the AWS configuration file and do not need to be provided explicitly via
-        the constructor.
+        the constructor. If the AWS environment is not configured, users need to provide the AWS credentials via the
+        constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`,
+        and `aws_region_name`.
 
         :param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to
             be specified in the format outlined in the Amazon Bedrock
             [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html).