diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 236347b61..2d33beb42 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .chat.chat_generator import AmazonBedrockChatGenerator from .generator import AmazonBedrockGenerator -__all__ = ["AmazonBedrockGenerator"] +__all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py new file mode 100644 index 000000000..a4eefe321 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -0,0 +1,266 @@ +import json +import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List + +from botocore.eventstream import EventStream +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from transformers import AutoTokenizer, PreTrainedTokenizer + +from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler + +logger = logging.getLogger(__name__) + + +class BedrockModelChatAdapter(ABC): + """ + Base class for Amazon Bedrock chat model adapters. 
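+
+    Each subclass prepares the model-specific request body from a list of ChatMessage instances
+    and extracts the generated replies from the model's (streaming) response.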
+ """ + + def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + self.generation_kwargs = generation_kwargs + + @abstractmethod + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + """Prepares the body for the Amazon Bedrock request.""" + + def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + """Extracts the responses from the Amazon Bedrock response.""" + return self._extract_messages_from_response(self.response_body_message_key(), response_body) + + def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]: + tokens: List[str] = [] + for event in stream: + chunk = event.get("chunk") + if chunk: + decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) + token = self._extract_token_from_stream(decoded_chunk) + # take all the rest key/value pairs from the chunk, add them to the metadata + stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token} + stream_chunk = StreamingChunk(content=token, meta=stream_metadata) + # callback the stream handler with StreamingChunk + stream_handler(stream_chunk) + tokens.append(token) + responses = ["".join(tokens).lstrip()] + return responses + + @staticmethod + def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None: + """ + Updates target_dict with values from updates_dict. Merges lists instead of overriding them. + + :param target_dict: The dictionary to update. + :param updates_dict: The dictionary with updates. + """ + for key, value in updates_dict.items(): + if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): + # Merge lists and remove duplicates + target_dict[key] = sorted(set(target_dict[key] + value)) + else: + # Override the value in target_dict + target_dict[key] = value + + def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: + """ + Merges params from inference_kwargs with the default params and self.generation_kwargs. + Uses a helper function to merge lists or override values as necessary. + + :param inference_kwargs: The inference kwargs to merge. + :param default_params: The default params to start with. + :return: The merged params. + """ + # Start with a copy of default_params + kwargs = default_params.copy() + + # Update the default params with self.generation_kwargs and finally inference_kwargs + self._update_params(kwargs, self.generation_kwargs) + self._update_params(kwargs, inference_kwargs) + + return kwargs + + def _ensure_token_limit(self, prompt: str) -> str: + resize_info = self.check_prompt(prompt) + if resize_info["prompt_length"] != resize_info["new_prompt_length"]: + logger.warning( + "The prompt was truncated from %s tokens to %s tokens so that the prompt length and " + "the answer length (%s tokens) fit within the model's max token limit (%s tokens). " + "Shorten the prompt or it will be cut off.", + resize_info["prompt_length"], + max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore + resize_info["max_length"], + resize_info["model_max_length"], + ) + return str(resize_info["resized_prompt"]) + + @abstractmethod + def check_prompt(self, prompt: str) -> Dict[str, Any]: + """ + Checks the prompt length and resizes it if necessary. + + :param prompt: The prompt to check. + :return: A dictionary containing the resized prompt and additional information. 
+ """ + + def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]: + metadata = {k: v for (k, v) in response_body.items() if k != message_tag} + return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] + + @abstractmethod + def response_body_message_key(self) -> str: + """Returns the key for the message in the response body.""" + + @abstractmethod + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + """Extracts the token from a streaming chunk.""" + + +class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): + """ + Model adapter for the Anthropic Claude model. + """ + + ANTHROPIC_USER_TOKEN = "\n\nHuman:" + ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:" + + def __init__(self, generation_kwargs: Dict[str, Any]): + super().__init__(generation_kwargs) + + # We pop the model_max_length as it is not sent to the model + # but used to truncate the prompt if needed + # Anthropic Claude has a limit of at least 100000 tokens + # https://docs.anthropic.com/claude/reference/input-and-output-sizes + model_max_length = self.generation_kwargs.pop("model_max_length", 100000) + + # Truncate prompt if prompt tokens > model_max_length-max_length + # (max_length is the length of the generated text) + # TODO use Anthropic tokenizer to get the precise prompt length + # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting + self.prompt_handler = DefaultPromptHandler( + tokenizer="gpt2", + model_max_length=model_max_length, + max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512, + ) + + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + default_params = { + "max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512, + "stop_sequences": ["\n\nHuman:"], + } + + # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it + stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) + if stop_sequences: + inference_kwargs["stop_sequences"] = stop_sequences + params = self._get_params(inference_kwargs, default_params) + body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + return body + + def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + conversation = [] + for index, message in enumerate(messages): + if message.is_from(ChatRole.USER): + conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}") + elif message.is_from(ChatRole.ASSISTANT): + conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") + elif message.is_from(ChatRole.FUNCTION): + error_message = "Anthropic does not support function calls." + raise ValueError(error_message) + elif message.is_from(ChatRole.SYSTEM) and index == 0: + # Until we transition to the new chat message format system messages will be ignored + # see https://docs.anthropic.com/claude/reference/messages_post for more details + logger.warning( + "System messages are not fully supported by the current version of Claude and will be ignored." 
+ ) + else: + invalid_role = f"Invalid role {message.role} for message {message.content}" + raise ValueError(invalid_role) + + prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " " + return self._ensure_token_limit(prepared_prompt) + + def check_prompt(self, prompt: str) -> Dict[str, Any]: + return self.prompt_handler(prompt) + + def response_body_message_key(self) -> str: + return "completion" + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + return chunk.get("completion", "") + + +class MetaLlama2ChatAdapter(BedrockModelChatAdapter): + """ + Model adapter for the Meta Llama 2 models. + """ + + chat_template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" + "{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + + def __init__(self, generation_kwargs: Dict[str, Any]) -> None: + super().__init__(generation_kwargs) + # We pop the model_max_length as it is not sent to the model + # but used to truncate the prompt if needed + # Llama 2 has context window size of 4096 tokens + # with some exceptions when the context window has been extended + model_max_length = self.generation_kwargs.pop("model_max_length", 4096) + + # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2 + # a) we should get good estimates for the prompt length (empirically close to llama 2) + # b) we can use apply_chat_template with the template above to delineate ChatMessages + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base") + tokenizer.bos_token = "" + tokenizer.eos_token = "" + tokenizer.unk_token = "" + self.prompt_handler = DefaultPromptHandler( + tokenizer=tokenizer, + model_max_length=model_max_length, + max_length=self.generation_kwargs.get("max_gen_len") or 512, + ) + + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} + + # combine stop words with default stop sequences, remove stop_words as MetaLlama2 does not support it + stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) + if stop_sequences: + inference_kwargs["stop_sequences"] = stop_sequences + params = self._get_params(inference_kwargs, default_params) + body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + return body + + def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( + conversation=messages, tokenize=False, chat_template=self.chat_template + ) + return self._ensure_token_limit(prepared_prompt) + + def check_prompt(self, prompt: 
str) -> Dict[str, Any]: + return self.prompt_handler(prompt) + + def response_body_message_key(self) -> str: + return "generation" + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + return chunk.get("generation", "") diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py new file mode 100644 index 000000000..804d44413 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -0,0 +1,249 @@ +import json +import logging +import re +from typing import Any, Callable, ClassVar, Dict, List, Optional, Type + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError +from haystack import component, default_from_dict, default_to_dict +from haystack.components.generators.utils import deserialize_callback_handler +from haystack.dataclasses import ChatMessage, StreamingChunk + +from haystack_integrations.components.generators.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, + AWSConfigurationError, +) + +from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter + +logger = logging.getLogger(__name__) + +AWS_CONFIGURATION_KEYS = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_profile_name", +] + + +@component +class AmazonBedrockChatGenerator: + """ + AmazonBedrockChatGenerator enables text generation via Amazon Bedrock chat hosted models. For example, to use + the Anthropic Claude model, simply initialize the AmazonBedrockChatGenerator with the 'anthropic.claude-v2' + model name. + + ```python + from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator + from haystack.dataclasses import ChatMessage + from haystack.components.generators.utils import print_streaming_chunk + + messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"), + ChatMessage.from_user("What's Natural Language Processing?")] + + + client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk) + client.run(messages, generation_kwargs={"max_tokens_to_sample": 512}) + + ``` + + If you prefer non-streaming mode, simply remove the `streaming_callback` parameter, capture the return value of the + component's run method and the AmazonBedrockChatGenerator will return the response in a non-streaming mode. + """ + + SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { + r"anthropic.claude.*": AnthropicClaudeChatAdapter, + r"meta.llama2.*": MetaLlama2ChatAdapter, + } + + def __init__( + self, + model: str, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + stop_words: Optional[List[str]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): + """ + Initializes the AmazonBedrockChatGenerator with the provided parameters. The parameters are passed to the + Amazon Bedrock client. + + Note that the AWS credentials are not required if the AWS environment is configured correctly. 
These are loaded + automatically from the environment or the AWS configuration file and do not need to be provided explicitly via + the constructor. If the AWS environment is not configured users need to provide the AWS credentials via the + constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`, + and `aws_region_name`. + + :param model: The model to use for generation. The model must be available in Amazon Bedrock. The model has to + be specified in the format outlined in the Amazon Bedrock [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html). + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param generation_kwargs: Additional generation keyword arguments passed to the model. The defined keyword + parameters are specific to a specific model and can be found in the model's documentation. For example, the + Anthropic Claude generation parameters can be found [here](https://docs.anthropic.com/claude/reference/complete_post). + :param stop_words: A list of stop words that stop model generation when encountered. They can be provided via + this parameter or via models generation_kwargs under a model's specific key for stop words. For example, the + Anthropic Claude stop words are provided via the `stop_sequences` key. + :param streaming_callback: A callback function that is called when a new chunk is received from the stream. + By default, the model is not set up for streaming. To enable streaming simply set this parameter to a callback + function that will handle the streaming chunks. The callback function will receive a StreamingChunk object and + switch the streaming mode on. + """ + if not model: + msg = "'model' cannot be None or empty string" + raise ValueError(msg) + self.model = model + + # get the model adapter for the given model + model_adapter_cls = self.get_model_adapter(model=model) + if not model_adapter_cls: + msg = f"AmazonBedrockGenerator doesn't support the model {model}." + raise AmazonBedrockConfigurationError(msg) + self.model_adapter = model_adapter_cls(generation_kwargs or {}) + + # create the AWS session and client + try: + session = self.get_aws_session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_profile_name=aws_profile_name, + ) + self.client = session.client("bedrock-runtime") + except Exception as exception: + msg = ( + "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) + raise AmazonBedrockConfigurationError(msg) from exception + + self.stop_words = stop_words or [] + self.streaming_callback = streaming_callback + + def invoke(self, *args, **kwargs): + kwargs = kwargs.copy() + messages: List[ChatMessage] = kwargs.pop("messages", []) + # check if the prompt is a list of ChatMessage objects + if not ( + isinstance(messages, list) + and len(messages) > 0 + and all(isinstance(message, ChatMessage) for message in messages) + ): + msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." 
+ raise ValueError(msg) + + body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs) + try: + if self.streaming_callback: + response = self.client.invoke_model_with_response_stream( + body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + ) + response_stream = response["body"] + responses = self.model_adapter.get_stream_responses( + stream=response_stream, stream_handler=self.streaming_callback + ) + else: + response = self.client.invoke_model( + body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + ) + response_body = json.loads(response.get("body").read().decode("utf-8")) + responses = self.model_adapter.get_responses(response_body=response_body) + except ClientError as exception: + msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" + raise AmazonBedrockInferenceError(msg) from exception + + return responses + + @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + return {"replies": self.invoke(messages=messages, **(generation_kwargs or {}))} + + @classmethod + def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]: + for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): + if re.fullmatch(pattern, model): + return adapter + return None + + @classmethod + def aws_configured(cls, **kwargs) -> bool: + """ + Checks whether AWS configuration is provided. + :param kwargs: The kwargs passed down to the generator. + :return: True if AWS configuration is provided, False otherwise. + """ + aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS) + return aws_config_provided + + @classmethod + def get_aws_session( + cls, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + **kwargs, + ): + """ + Creates an AWS Session with the given parameters. + Checks if the provided AWS credentials are valid and can be used to connect to AWS. + + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. + See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. + :raises AWSConfigurationError: If the provided AWS credentials are invalid. + :return: The created AWS session. + """ + try: + return boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=aws_region_name, + profile_name=aws_profile_name, + ) + except BotoCoreError as e: + provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} + msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" + raise AWSConfigurationError(msg) from e + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + :return: The serialized component as a dictionary. 
+ """ + return default_to_dict( + self, + model=self.model, + stop_words=self.stop_words, + generation_kwargs=self.model_adapter.generation_kwargs, + streaming_callback=self.streaming_callback, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 4c43c9a09..8e89dab59 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -112,7 +112,7 @@ def __init__( # It is hard to determine which tokenizer to use for the SageMaker model # so we use GPT2 tokenizer which will likely provide good token count approximation self.prompt_handler = DefaultPromptHandler( - model="gpt2", + tokenizer="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100, ) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py index 56dcb24d3..b7b555ec0 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict, Union -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast class DefaultPromptHandler: @@ -10,8 +10,20 @@ class DefaultPromptHandler: are within the model_max_length. """ - def __init__(self, model: str, model_max_length: int, max_length: int = 100): - self.tokenizer = AutoTokenizer.from_pretrained(model) + def __init__(self, tokenizer: Union[str, PreTrainedTokenizerBase], model_max_length: int, max_length: int = 100): + """ + :param tokenizer: The tokenizer to be used to tokenize the prompt. + :param model_max_length: The maximum length of the prompt and answer tokens combined. + :param max_length: The maximum length of the answer tokens. 
+ """ + if isinstance(tokenizer, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + elif isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + self.tokenizer = tokenizer + else: + msg = "model must be a string or a PreTrainedTokenizer instance" + raise ValueError(msg) + self.tokenizer.model_max_length = model_max_length self.model_max_length = model_max_length self.max_length = max_length diff --git a/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py new file mode 100644 index 000000000..9592b5b39 --- /dev/null +++ b/integrations/amazon_bedrock/tests/test_amazon_chat_bedrock.py @@ -0,0 +1,250 @@ +from typing import Optional, Type +from unittest.mock import MagicMock, patch + +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage + +from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator +from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( + AnthropicClaudeChatAdapter, + BedrockModelChatAdapter, + MetaLlama2ChatAdapter, +) + +clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" + + +@pytest.fixture +def mock_auto_tokenizer(): + with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained: + mock_tokenizer = MagicMock() + mock_from_pretrained.return_value = mock_tokenizer + yield mock_tokenizer + + +# create a fixture with mocked boto3 client and session +@pytest.fixture +def mock_boto3_session(): + with patch("boto3.Session") as mock_client: + yield mock_client + + +@pytest.fixture +def mock_prompt_handler(): + with patch( + "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" + ) as mock_prompt_handler: + yield mock_prompt_handler + + +def test_to_dict(mock_auto_tokenizer, mock_boto3_session): + """ + Test that the to_dict method returns the correct dictionary without aws credentials + """ + generator = AmazonBedrockChatGenerator( + model="anthropic.claude-v2", + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + aws_profile_name="some_fake_profile", + aws_region_name="fake_region", + generation_kwargs={"temperature": 0.7}, + streaming_callback=print_streaming_chunk, + ) + expected_dict = { + "type": clazz, + "init_parameters": { + "model": "anthropic.claude-v2", + "generation_kwargs": {"temperature": 0.7}, + "stop_words": [], + "streaming_callback": print_streaming_chunk, + }, + } + + assert generator.to_dict() == expected_dict + + +def test_from_dict(mock_auto_tokenizer, mock_boto3_session): + """ + Test that the from_dict method returns the correct object + """ + generator = AmazonBedrockChatGenerator.from_dict( + { + "type": clazz, + "init_parameters": { + "model": "anthropic.claude-v2", + "generation_kwargs": {"temperature": 0.7}, + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + }, + } + ) + assert generator.model == "anthropic.claude-v2" + assert generator.model_adapter.generation_kwargs == {"temperature": 0.7} + assert generator.streaming_callback == print_streaming_chunk + + +def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): + """ + Test that the default constructor sets the correct values + """ + + layer = AmazonBedrockChatGenerator( + 
model="anthropic.claude-v2", + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + aws_profile_name="some_fake_profile", + aws_region_name="fake_region", + ) + + assert layer.model == "anthropic.claude-v2" + + assert layer.model_adapter.prompt_handler is not None + assert layer.model_adapter.prompt_handler.model_max_length == 100000 + + # assert mocked boto3 client called exactly once + mock_boto3_session.assert_called_once() + + # assert mocked boto3 client was called with the correct parameters + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + +def test_constructor_with_generation_kwargs(mock_auto_tokenizer, mock_boto3_session): + """ + Test that model_kwargs are correctly set in the constructor + """ + generation_kwargs = {"temperature": 0.7} + + layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs) + assert "temperature" in layer.model_adapter.generation_kwargs + assert layer.model_adapter.generation_kwargs["temperature"] == 0.7 + + +def test_constructor_with_empty_model(): + """ + Test that the constructor raises an error when the model is empty + """ + with pytest.raises(ValueError, match="cannot be None or empty string"): + AmazonBedrockChatGenerator(model="") + + +@pytest.mark.unit +def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): + """ + Test invoke raises an error if no messages are provided + """ + layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2") + with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires"): + layer.invoke() + + +@pytest.mark.unit +@pytest.mark.parametrize( + "model, expected_model_adapter", + [ + ("anthropic.claude-v1", AnthropicClaudeChatAdapter), + ("anthropic.claude-v2", AnthropicClaudeChatAdapter), + ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter), + ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial + ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), + ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), + ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("unknown_model", None), + ], +) +def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelChatAdapter]]): + """ + Test that the correct model adapter is returned for a given model + """ + model_adapter = AmazonBedrockChatGenerator.get_model_adapter(model=model) + assert model_adapter == expected_model_adapter + + +class TestAnthropicClaudeAdapter: + def test_prepare_body_with_default_params(self, mock_auto_tokenizer) -> None: + layer = AnthropicClaudeChatAdapter(generation_kwargs={}) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", + "max_tokens_to_sample": 512, + "stop_sequences": ["\n\nHuman:"], + } + + body = layer.prepare_body([ChatMessage.from_user(prompt)]) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self, mock_auto_tokenizer) -> None: + layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) + prompt = "Hello, how are you?" 
+ expected_body = { + "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", + "max_tokens_to_sample": 69, + "stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + } + + body = layer.prepare_body( + [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"] + ) + + assert body == expected_body + + +class TestMetaLlama2ChatAdapter: + @pytest.mark.integration + def test_prepare_body_with_default_params(self) -> None: + # leave this test as integration because we really need only tokenizer from HF + # that way we can ensure prompt chat message formatting + layer = MetaLlama2ChatAdapter(generation_kwargs={}) + prompt = "Hello, how are you?" + expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512} + + body = layer.prepare_body([ChatMessage.from_user(prompt)]) + + assert body == expected_body + + @pytest.mark.integration + def test_prepare_body_with_custom_inference_params(self) -> None: + # leave this test as integration because we really need only tokenizer from HF + # that way we can ensure prompt chat message formatting + layer = MetaLlama2ChatAdapter( + generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} + ) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "[INST] Hello, how are you? [/INST]", + "max_gen_len": 69, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + } + + body = layer.prepare_body( + [ChatMessage.from_user(prompt)], + temperature=0.7, + top_p=0.8, + top_k=5, + max_gen_len=69, + stop_sequences=["CUSTOM_STOP"], + ) + + assert body == expected_body + + @pytest.mark.integration + def test_get_responses(self) -> None: + adapter = MetaLlama2ChatAdapter(generation_kwargs={}) + response_body = {"generation": "This is a single response."} + expected_response = "This is a single response." + response_message = adapter.get_responses(response_body) + assert response_message == [ChatMessage.from_assistant(expected_response)]
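+
+
+# A minimal sketch, not part of the original changeset: it mirrors the Llama 2 get_responses test
+# above for the Claude adapter, assuming a response body that carries the reply under the
+# "completion" key (see AnthropicClaudeChatAdapter.response_body_message_key) plus one extra
+# field that should land in the message metadata.
+def test_claude_get_responses(mock_auto_tokenizer) -> None:
+    adapter = AnthropicClaudeChatAdapter(generation_kwargs={})
+    response_body = {"completion": "This is a single response.", "stop_reason": "stop_sequence"}
+    expected_message = ChatMessage.from_assistant(
+        "This is a single response.", meta={"stop_reason": "stop_sequence"}
+    )
+    assert adapter.get_responses(response_body) == [expected_message]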