From df14a979e05879a0165f1dbe7c2737a063c188d2 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 15:40:27 +0100 Subject: [PATCH] feat: Update AmazonBedrockChatGenerator to use Converse API (BREAKING CHANGE) (#1219) * Initial commit * Update models tested * Add tool support * Update Amazon Bedrock model names in tests * Support for tool streaming * Format * Minot test updates * Lint * Remove truncate init parameter * Pull try down * Add extract_replies_from_response unit test * Add process_streaming_response unit test * Lint * Small test fix * Use EventStream from botocore * Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py Co-authored-by: Stefano Fiorucci * Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py Co-authored-by: Stefano Fiorucci --------- Co-authored-by: Stefano Fiorucci --- .../amazon_bedrock/chat/adapters.py | 569 -------------- .../amazon_bedrock/chat/chat_generator.py | 294 +++++--- .../tests/test_chat_generator.py | 704 +++++++----------- 3 files changed, 467 insertions(+), 1100 deletions(-) delete mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py deleted file mode 100644 index cbb5ee370..000000000 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ /dev/null @@ -1,569 +0,0 @@ -import json -import logging -import os -from abc import ABC, abstractmethod -from typing import Any, Callable, ClassVar, Dict, List, Optional - -from botocore.eventstream import EventStream -from haystack.components.generators.openai_utils import _convert_message_to_openai_format -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk -from transformers import AutoTokenizer, PreTrainedTokenizer - -from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler - -logger = logging.getLogger(__name__) - - -class BedrockModelChatAdapter(ABC): - """ - Base class for Amazon Bedrock chat model adapters. - - Each subclass of this class is designed to address the unique specificities of a particular chat LLM it adapts, - focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted chat LLMs. - """ - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: - """ - Initializes the chat adapter with the truncate parameter and generation kwargs. - """ - self.generation_kwargs = generation_kwargs - self.truncate = truncate - - @abstractmethod - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Amazon Bedrock request. - Subclasses should override this method to package the chat messages into the request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - - def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the responses from the Amazon Bedrock response. - - :param response_body: The response body. - :returns: The extracted responses. - """ - return self._extract_messages_from_response(response_body) - - def get_stream_responses( - self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] - ) -> List[ChatMessage]: - streaming_chunks: List[StreamingChunk] = [] - last_decoded_chunk: Dict[str, Any] = {} - for event in stream: - chunk = event.get("chunk") - if chunk: - last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - streaming_chunk = self._build_streaming_chunk(last_decoded_chunk) - streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk - streaming_chunks.append(streaming_chunk) - responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()] - return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] - - @staticmethod - def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None: - """ - Updates target_dict with values from updates_dict. Merges lists instead of overriding them. - - :param target_dict: The dictionary to update. - :param updates_dict: The dictionary with updates. - :param allowed_params: The list of allowed params to use. - """ - for key, value in updates_dict.items(): - if key not in allowed_params: - logger.warning(f"Parameter '{key}' is not allowed and will be ignored.") - continue - if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): - # Merge lists and remove duplicates - target_dict[key] = sorted(set(target_dict[key] + value)) - else: - # Override the value in target_dict - target_dict[key] = value - - def _get_params( - self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str] - ) -> Dict[str, Any]: - """ - Merges params from inference_kwargs with the default params and self.generation_kwargs. - Uses a helper function to merge lists or override values as necessary. - - :param inference_kwargs: The inference kwargs to merge. - :param default_params: The default params to start with. - :param allowed_params: The list of allowed params to use. - :returns: The merged params. - """ - # Start with a copy of default_params - kwargs = default_params.copy() - - # Update the default params with self.generation_kwargs and finally inference_kwargs - self._update_params(kwargs, self.generation_kwargs, allowed_params) - self._update_params(kwargs, inference_kwargs, allowed_params) - - return kwargs - - def _ensure_token_limit(self, prompt: str) -> str: - """ - Ensures that the prompt is within the token limit for the model. - :param prompt: The prompt to check. - :returns: The resized prompt. - """ - resize_info = self.check_prompt(prompt) - if resize_info["prompt_length"] != resize_info["new_prompt_length"]: - logger.warning( - "The prompt was truncated from %s tokens to %s tokens so that the prompt length and " - "the answer length (%s tokens) fit within the model's max token limit (%s tokens). " - "Shorten the prompt or it will be cut off.", - resize_info["prompt_length"], - max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore - resize_info["max_length"], - resize_info["model_max_length"], - ) - return str(resize_info["resized_prompt"]) - - @abstractmethod - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - - @abstractmethod - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :returns: The extracted ChatMessage list. - """ - - @abstractmethod - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - - -class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Anthropic Claude chat model. - """ - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - ALLOWED_PARAMS: ClassVar[List[str]] = [ - "anthropic_version", - "max_tokens", - "stop_sequences", - "temperature", - "top_p", - "top_k", - "system", - "tools", - "tool_choice", - ] - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): - """ - Initializes the Anthropic Claude chat adapter. - - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Anthropic Claude has a limit of at least 100000 tokens - # https://docs.anthropic.com/claude/reference/input-and-output-sizes - model_max_length = self.generation_kwargs.pop("model_max_length", 100000) - - # Truncate prompt if prompt tokens > model_max_length-max_length - # (max_length is the length of the generated text) - # TODO use Anthropic tokenizer to get the precise prompt length - # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting - self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Anthropic Claude request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - default_params = { - "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required - } - - # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) - if stop_sequences: - inference_kwargs["stop_sequences"] = stop_sequences - # pop stream kwarg from inference_kwargs as Anthropic does not support it (if provided) - inference_kwargs.pop("stream", None) - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {**self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: - """ - Prepares the chat messages for the Anthropic Claude request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a dictionary. - """ - body: Dict[str, Any] = {} - system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None - body["messages"] = [ - self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT) - ] - if system: - body["system"] = system - # Ensure token limit for each message in the body - if self.truncate: - for message in body["messages"]: - for content in message["content"]: - content["text"] = self._ensure_token_limit(content["text"]) - return body - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - messages: List[ChatMessage] = [] - if response_body.get("type") == "message": - if response_body.get("stop_reason") == "tool_use": # If `tool_use` we only keep the tool_use content - for content in response_body["content"]: - if content.get("type") == "tool_use": - meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} - json_answer = json.dumps(content) - messages.append(ChatMessage.from_assistant(json_answer, meta=meta)) - else: # For other stop_reason, return all text content - for content in response_body["content"]: - if content.get("type") == "text": - meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} - messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) - - return messages - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": - return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) - return StreamingChunk(content="", meta=chunk) - - def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: - """ - Convert a ChatMessage to a dictionary with the content and role fields. - :param m: The ChatMessage to convert. - :return: The dictionary with the content and role fields. - """ - return {"content": [{"type": "text", "text": m.content}], "role": m.role.value} - - -class MistralChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Mistral chat model. - """ - - chat_template = """ - {% if messages[0]['role'] == 'system' %} - {% set loop_messages = messages[1:] %} - {% set system_message = messages[0]['content'] %} - {% else %} - {% set loop_messages = messages %} - {% set system_message = false %} - {% endif %} - {{bos_token}} - {% for message in loop_messages %} - {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if loop.index0 == 0 and system_message != false %} - {% set content = system_message + '\n' + message['content'] %} - {% else %} - {% set content = message['content'] %} - {% endif %} - {% if message['role'] == 'user' %} - {{ '[INST] ' + content.strip() + ' [/INST]' }} - {% elif message['role'] == 'assistant' %} - {{ content.strip() + eos_token }} - {% endif %} - {% endfor %} - """ - chat_template = "".join(line.strip() for line in chat_template.splitlines()) - - # the above template was designed to match https://docs.mistral.ai/models/#chat-template - # and to support system messages, otherwise we could use the default mistral chat template - # available on HF infrastructure - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - ALLOWED_PARAMS: ClassVar[List[str]] = [ - "max_tokens", - "safe_prompt", - "random_seed", - "temperature", - "top_p", - ] - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): - """ - Initializes the Mistral chat adapter. - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Mistral has a limit of at least 32000 tokens - model_max_length = self.generation_kwargs.pop("model_max_length", 32000) - - # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer - # a) we should get good estimates for the prompt length - # b) we can use apply_chat_template with the template above to delineate ChatMessages - # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model. - tokenizer: PreTrainedTokenizer - if os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN"): - tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") - else: - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf") - logger.warning( - "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for " - "estimating the prompt length because no HF_TOKEN was found. Using " - "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var " - "HF_TOKEN containing a Hugging Face token and make sure you have access to the model." - ) - - self.prompt_handler = DefaultPromptHandler( - tokenizer=tokenizer, - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Mistral request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - default_params = { - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required - } - # replace stop_words from inference_kwargs with stop, as this is Mistral specific parameter - stop_words = inference_kwargs.pop("stop_words", []) - if stop_words: - inference_kwargs["stop"] = stop_words - - # pop stream kwarg from inference_kwargs as Mistral does not support it (if provided) - inference_kwargs.pop("stream", None) - - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - """ - Prepares the chat messages for the Mistral request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string. - """ - # it would be great to use the default mistral chat template, but it doesn't support system messages - # the class variable defined chat_template is a workaround to support system messages - # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json - # but we'll use our custom chat template - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=[_convert_message_to_openai_format(m) for m in messages], - tokenize=False, - chat_template=self.chat_template, - ) - if self.truncate: - prepared_prompt = self._ensure_token_limit(prepared_prompt) - return prepared_prompt - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - messages: List[ChatMessage] = [] - responses = response_body.get("outputs", []) - for response in responses: - meta = {k: v for k, v in response.items() if k not in ["text"]} - messages.append(ChatMessage.from_assistant(response["text"], meta=meta)) - return messages - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - response_chunk = chunk.get("outputs", []) - if response_chunk: - return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk) - return StreamingChunk(content="", meta=chunk) - - -class MetaLlama2ChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Meta Llama 2 models. - """ - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html - ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"] - - chat_template = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'] %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = false %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 and system_message != false %}" - "{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ ' ' + content.strip() + ' ' + eos_token }}" - "{% endif %}" - "{% endfor %}" - ) - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: - """ - Initializes the Meta Llama 2 chat adapter. - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Llama 2 has context window size of 4096 tokens - # with some exceptions when the context window has been extended - model_max_length = self.generation_kwargs.pop("model_max_length", 4096) - - # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2 - # a) we should get good estimates for the prompt length (empirically close to llama 2) - # b) we can use apply_chat_template with the template above to delineate ChatMessages - tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base") - tokenizer.bos_token = "" - tokenizer.eos_token = "" - tokenizer.unk_token = "" - self.prompt_handler = DefaultPromptHandler( - tokenizer=tokenizer, - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_gen_len") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Meta Llama 2 request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - """ - default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} - - # no support for stop words in Meta Llama 2 - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - """ - Prepares the chat messages for the Meta Llama 2 request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string ready for the model. - """ - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=messages, tokenize=False, chat_template=self.chat_template - ) - - if self.truncate: - prepared_prompt = self._ensure_token_limit(prepared_prompt) - return prepared_prompt - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - message_tag = "generation" - metadata = {k: v for (k, v) in response_body.items() if k != message_tag} - return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 183198bce..499fe1c24 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -1,12 +1,12 @@ import json import logging -import re -from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional from botocore.config import Config +from botocore.eventstream import EventStream from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -16,18 +16,16 @@ ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session -from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter, MistralChatAdapter - logger = logging.getLogger(__name__) @component class AmazonBedrockChatGenerator: """ - Completes chats using LLMs hosted on Amazon Bedrock. + Completes chats using LLMs hosted on Amazon Bedrock available via the Bedrock Converse API. For example, to use the Anthropic Claude 3 Sonnet model, initialize this component with the - 'anthropic.claude-3-sonnet-20240229-v1:0' model name. + 'anthropic.claude-3-5-sonnet-20240620-v1:0' model name. ### Usage example @@ -40,7 +38,7 @@ class AmazonBedrockChatGenerator: ChatMessage.from_user("What's Natural Language Processing?")] - client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", + client = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", streaming_callback=print_streaming_chunk) client.run(messages, generation_kwargs={"max_tokens": 512}) @@ -58,12 +56,6 @@ class AmazonBedrockChatGenerator: supports Amazon Bedrock. """ - SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, - r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter, - r"([a-z]{2}\.)?mistral.*": MistralChatAdapter, - } - def __init__( self, model: str, @@ -77,7 +69,6 @@ def __init__( generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - truncate: Optional[bool] = True, boto3_config: Optional[Dict[str, Any]] = None, ): """ @@ -111,7 +102,6 @@ def __init__( function that handles the streaming chunks. The callback function receives a [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches the streaming mode on. - :param truncate: Whether to truncate the prompt messages or not. :param boto3_config: The configuration for the boto3 client. :raises ValueError: If the model name is empty or None. @@ -129,17 +119,8 @@ def __init__( self.aws_profile_name = aws_profile_name self.stop_words = stop_words or [] self.streaming_callback = streaming_callback - self.truncate = truncate self.boto3_config = boto3_config - # get the model adapter for the given model - model_adapter_cls = self.get_model_adapter(model=model) - if not model_adapter_cls: - msg = f"AmazonBedrockGenerator doesn't support the model {model}." - raise AmazonBedrockConfigurationError(msg) - self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {}) - - # create the AWS session and client def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -162,89 +143,9 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception - @component.output_types(replies=List[ChatMessage]) - def run( - self, - messages: List[ChatMessage], - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - ): - """ - Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. - - :param messages: The messages to generate a response to. - :param streaming_callback: - A callback function that is called when a new token is received from the stream. - :param generation_kwargs: Additional generation keyword arguments passed to the model. - :returns: A dictionary with the following keys: - - `replies`: The generated List of `ChatMessage` objects. - """ - generation_kwargs = generation_kwargs or {} - generation_kwargs = generation_kwargs.copy() - - streaming_callback = streaming_callback or self.streaming_callback - generation_kwargs["stream"] = streaming_callback is not None - - # check if the prompt is a list of ChatMessage objects - if not ( - isinstance(messages, list) - and len(messages) > 0 - and all(isinstance(message, ChatMessage) for message in messages) - ): - msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." - raise ValueError(msg) - - body = self.model_adapter.prepare_body( - messages=messages, **{"stop_words": self.stop_words, **generation_kwargs} - ) - try: - if streaming_callback: - response = self.client.invoke_model_with_response_stream( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" - ) - response_stream = response["body"] - replies = self.model_adapter.get_stream_responses( - stream=response_stream, streaming_callback=streaming_callback - ) - else: - response = self.client.invoke_model( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" - ) - response_body = json.loads(response.get("body").read().decode("utf-8")) - replies = self.model_adapter.get_responses(response_body=response_body) - except ClientError as exception: - msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" - raise AmazonBedrockInferenceError(msg) from exception - - # rename the meta key to be inline with OpenAI meta output keys - for response in replies: - if response.meta: - if "usage" in response.meta: - if "input_tokens" in response.meta["usage"]: - response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens") - if "output_tokens" in response.meta["usage"]: - response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens") - else: - response.meta["usage"] = {} - if "prompt_token_count" in response.meta: - response.meta["usage"]["prompt_tokens"] = response.meta.pop("prompt_token_count") - if "generation_token_count" in response.meta: - response.meta["usage"]["completion_tokens"] = response.meta.pop("generation_token_count") - - return {"replies": replies} - - @classmethod - def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]: - """ - Returns the model adapter for the given model. - - :param model: The model to get the adapter for. - :returns: The model adapter for the given model, or None if the model is not supported. - """ - for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): - if re.fullmatch(pattern, model): - return adapter - return None + self.generation_kwargs = generation_kwargs or {} + self.stop_words = stop_words or [] + self.streaming_callback = streaming_callback def to_dict(self) -> Dict[str, Any]: """ @@ -263,9 +164,8 @@ def to_dict(self) -> Dict[str, Any]: aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, stop_words=self.stop_words, - generation_kwargs=self.model_adapter.generation_kwargs, + generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, - truncate=self.truncate, boto3_config=self.boto3_config, ) @@ -274,10 +174,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": """ Deserializes the component from a dictionary. - :param data: - Dictionary to deserialize from. + :param data: Dictionary with serialized data. :returns: - Deserialized component. + Instance of `AmazonBedrockChatGenerator`. """ init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") @@ -288,3 +187,172 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) return default_from_dict(cls, data) + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + generation_kwargs = generation_kwargs or {} + + # Merge generation_kwargs with defaults + merged_kwargs = self.generation_kwargs.copy() + merged_kwargs.update(generation_kwargs) + + # Extract known inference parameters + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html + inference_config = { + key: merged_kwargs.pop(key, None) + for key in ["maxTokens", "stopSequences", "temperature", "topP"] + if key in merged_kwargs + } + + # Extract tool configuration if present + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = merged_kwargs.pop("toolConfig", None) + + # Any remaining kwargs go to additionalModelRequestFields + additional_fields = merged_kwargs if merged_kwargs else None + + # Prepare system prompts and messages + system_prompts = [] + if messages and messages[0].is_from(ChatRole.SYSTEM): + system_prompts = [{"text": messages[0].text}] + messages = messages[1:] + + messages_list = [{"role": msg.role.value, "content": [{"text": msg.text}]} for msg in messages] + + # Build API parameters + params = { + "modelId": self.model, + "messages": messages_list, + "system": system_prompts, + "inferenceConfig": inference_config, + } + if tool_config: + params["toolConfig"] = tool_config + if additional_fields: + params["additionalModelRequestFields"] = additional_fields + + callback = streaming_callback or self.streaming_callback + + try: + if callback: + response = self.client.converse_stream(**params) + response_stream: EventStream = response.get("stream") + if not response_stream: + msg = "No stream found in the response." + raise AmazonBedrockInferenceError(msg) + replies = self.process_streaming_response(response_stream, callback) + else: + response = self.client.converse(**params) + replies = self.extract_replies_from_response(response) + except ClientError as exception: + msg = f"Could not generate inference for Amazon Bedrock model {self.model} due: {exception}" + raise AmazonBedrockInferenceError(msg) from exception + + return {"replies": replies} + + def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + replies = [] + if "output" in response_body and "message" in response_body["output"]: + message = response_body["output"]["message"] + if message["role"] == "assistant": + content_blocks = message["content"] + + # Common meta information + base_meta = { + "model": self.model, + "index": 0, + "finish_reason": response_body.get("stopReason"), + "usage": { + # OpenAI's format for usage for cross ChatGenerator compatibility + "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), + "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), + "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), + }, + } + + # Process each content block separately + for content_block in content_blocks: + if "text" in content_block: + replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy())) + elif "toolUse" in content_block: + replies.append( + ChatMessage.from_assistant( + content=json.dumps(content_block["toolUse"]), meta=base_meta.copy() + ) + ) + return replies + + def process_streaming_response( + self, response_stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: + replies = [] + current_content = "" + current_tool_use = None + base_meta = { + "model": self.model, + "index": 0, + } + + for event in response_stream: + if "contentBlockStart" in event: + # Reset accumulators for new message + current_content = "" + current_tool_use = None + block_start = event["contentBlockStart"] + if "start" in block_start and "toolUse" in block_start["start"]: + tool_start = block_start["start"]["toolUse"] + current_tool_use = { + "toolUseId": tool_start["toolUseId"], + "name": tool_start["name"], + "input": "", # Will accumulate deltas as string + } + + elif "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + delta_text = delta["text"] + current_content += delta_text + streaming_chunk = StreamingChunk(content=delta_text, meta=None) + # it only makes sense to call callback on text deltas + streaming_callback(streaming_chunk) + elif "toolUse" in delta and current_tool_use: + # Accumulate tool use input deltas + current_tool_use["input"] += delta["toolUse"].get("input", "") + elif "contentBlockStop" in event: + if current_tool_use: + # Parse accumulated input if it's a JSON string + try: + input_json = json.loads(current_tool_use["input"]) + current_tool_use["input"] = input_json + except json.JSONDecodeError: + # Keep as string if not valid JSON + pass + + tool_content = json.dumps(current_tool_use) + replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy())) + elif current_content: + replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy())) + + elif "messageStop" in event: + # not 100% correct for multiple messages but no way around it + for reply in replies: + reply.meta["finish_reason"] = event["messageStop"].get("stopReason") + + elif "metadata" in event: + metadata = event["metadata"] + # not 100% correct for multiple messages but no way around it + for reply in replies: + if "usage" in metadata: + usage = metadata["usage"] + reply.meta["usage"] = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + + return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8d6a5c3ee..8eb29729c 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,30 +1,36 @@ import json -import logging -import os -from typing import Any, Dict, Optional, Type -from unittest.mock import MagicMock, patch +from typing import Any, Dict, Optional import pytest from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator -from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( - AnthropicClaudeChatAdapter, - BedrockModelChatAdapter, - MetaLlama2ChatAdapter, - MistralChatAdapter, -) KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"] -MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"] -MISTRAL_MODELS = [ - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", +MODELS_TO_TEST = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "cohere.command-r-plus-v1:0", + "mistral.mistral-large-2402-v1:0", +] +MODELS_TO_TEST_WITH_TOOLS = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0", ] +# so far we've discovered these models support streaming and tool use +STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"] + + +@pytest.fixture +def chat_messages(): + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + return messages + @pytest.mark.parametrize( "boto3_config", @@ -35,12 +41,12 @@ }, ], ) -def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]): +def test_to_dict(mock_boto3_session, boto3_config): """ Test that the to_dict method returns the correct dictionary without aws credentials """ generator = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", + model="cohere.command-r-plus-v1:0", generation_kwargs={"temperature": 0.7}, streaming_callback=print_streaming_chunk, boto3_config=boto3_config, @@ -53,11 +59,10 @@ def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]] "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, - "model": "anthropic.claude-v2", + "model": "cohere.command-r-plus-v1:0", "generation_kwargs": {"temperature": 0.7}, "stop_words": [], "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "truncate": True, "boto3_config": boto3_config, }, } @@ -87,16 +92,14 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, - "model": "anthropic.claude-v2", + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "truncate": True, "boto3_config": boto3_config, }, } ) - assert generator.model == "anthropic.claude-v2" - assert generator.model_adapter.generation_kwargs == {"temperature": 0.7} + assert generator.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert generator.streaming_callback == print_streaming_chunk assert generator.boto3_config == boto3_config @@ -107,13 +110,10 @@ def test_default_constructor(mock_boto3_session, set_env_variables): """ layer = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", + model="anthropic.claude-3-5-sonnet-20240620-v1:0", ) - assert layer.model == "anthropic.claude-v2" - assert layer.truncate is True - assert layer.model_adapter.prompt_handler is not None - assert layer.model_adapter.prompt_handler.model_max_length == 100000 + assert layer.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" # assert mocked boto3 client called exactly once mock_boto3_session.assert_called_once() @@ -134,18 +134,10 @@ def test_constructor_with_generation_kwargs(mock_boto3_session): """ generation_kwargs = {"temperature": 0.7} - layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs) - assert "temperature" in layer.model_adapter.generation_kwargs - assert layer.model_adapter.generation_kwargs["temperature"] == 0.7 - assert layer.model_adapter.truncate is True - - -def test_constructor_with_truncate(mock_boto3_session): - """ - Test that truncate param is correctly set in the model constructor - """ - layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", truncate=False) - assert layer.model_adapter.truncate is False + layer = AmazonBedrockChatGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs + ) + assert layer.generation_kwargs == generation_kwargs def test_constructor_with_empty_model(): @@ -156,208 +148,15 @@ def test_constructor_with_empty_model(): AmazonBedrockChatGenerator(model="") -def test_short_prompt_is_not_truncated(mock_boto3_session): - """ - Test that a short prompt is not truncated - """ - # Define a short mock prompt and its tokenized version - mock_prompt_text = "I am a tokenized prompt" - mock_prompt_tokens = mock_prompt_text.split() - - # Mock the tokenizer so it returns our predefined tokens - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize.return_value = mock_prompt_tokens - - # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens - # Since our mock prompt is 5 tokens long, it doesn't exceed the - # total limit (5 prompt tokens + 3 generated tokens < 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): - layer = AmazonBedrockChatGenerator( - "anthropic.claude-v2", - generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - prompt_after_resize = layer.model_adapter._ensure_token_limit(mock_prompt_text) - - # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it - assert prompt_after_resize == mock_prompt_text - - -def test_long_prompt_is_truncated(mock_boto3_session): - """ - Test that a long prompt is truncated - """ - # Define a long mock prompt and its tokenized version - long_prompt_text = "I am a tokenized prompt of length eight" - long_prompt_tokens = long_prompt_text.split() - - # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit - truncated_prompt_text = "I am a tokenized prompt of length" - - # Mock the tokenizer to return our predefined tokens - # convert tokens to our predefined truncated text - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize.return_value = long_prompt_tokens - mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text - - # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens - # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): - layer = AmazonBedrockChatGenerator( - "anthropic.claude-v2", - generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - prompt_after_resize = layer.model_adapter._ensure_token_limit(long_prompt_text) +class TestAmazonBedrockChatGeneratorInference: - # The prompt exceeds the limit, _ensure_token_limit truncates it - assert prompt_after_resize == truncated_prompt_text - - -def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): - """ - Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False - """ - messages = [ChatMessage.from_user("What is the biggest city in United States?")] - - # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()): - generator = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", - truncate=False, - generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - - # Mock the _ensure_token_limit method to track if it is called - with patch.object( - generator.model_adapter, "_ensure_token_limit", wraps=generator.model_adapter._ensure_token_limit - ) as mock_ensure_token_limit: - # Mock the model adapter to avoid actual invocation - generator.model_adapter.prepare_body = MagicMock(return_value={}) - generator.client = MagicMock() - generator.client.invoke_model = MagicMock( - return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} - ) - - generator.model_adapter.get_responses = MagicMock( - return_value=[ - ChatMessage.from_assistant( - content="Some text", - meta={ - "model": "claude-3-sonnet-20240229", - "index": 0, - "finish_reason": "end_turn", - "usage": {"prompt_tokens": 16, "completion_tokens": 55}, - }, - ) - ] - ) - # Invoke the generator - generator.run(messages=messages) - - # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called() - - # Check the prompt passed to prepare_body - generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False) - - -@pytest.mark.parametrize( - "model, expected_model_adapter", - [ - ("anthropic.claude-v1", AnthropicClaudeChatAdapter), - ("anthropic.claude-v2", AnthropicClaudeChatAdapter), - ("eu.anthropic.claude-v1", AnthropicClaudeChatAdapter), # cross-region inference - ("us.anthropic.claude-v2", AnthropicClaudeChatAdapter), # cross-region inference - ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter), - ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial - ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial - ("us.meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference - ("eu.meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference - ("de.meta.llama2-130b-v5", MetaLlama2ChatAdapter), # cross-region inference - ("unknown_model", None), - ], -) -def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelChatAdapter]]): - """ - Test that the correct model adapter is returned for a given model - """ - model_adapter = AmazonBedrockChatGenerator.get_model_adapter(model=model) - assert model_adapter == expected_model_adapter - - -class TestAnthropicClaudeAdapter: - def test_prepare_body_with_default_params(self) -> None: - layer = AnthropicClaudeChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": 512, - "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], - } - - body = layer.prepare_body([ChatMessage.from_user(prompt)]) - - assert body == expected_body - - def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeChatAdapter( - truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4} - ) - prompt = "Hello, how are you?" - expected_body = { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": 512, - "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], - "stop_sequences": ["CUSTOM_STOP"], - "temperature": 0.7, - "top_k": 5, - "top_p": 0.8, - } - - body = layer.prepare_body( - [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"] - ) - - assert body == expected_body - - @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration - def test_tools_use(self, model_name): - """ - Test function calling with AWS Bedrock Anthropic adapter - """ - # See https://docs.anthropic.com/en/docs/tool-use for more information - tools = [ - { - "name": "top_song", - "description": "Get the most popular song played on a radio station.", - "input_schema": { - "type": "object", - "properties": { - "sign": { - "type": "string", - "description": "The call sign for the radio station for which you want the most popular" - " song. Example calls signs are WZPZ and WKRP.", - } - }, - "required": ["sign"], - }, - } - ] - messages = [] - messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + def test_default_inference_params(self, model_name, chat_messages): client = AmazonBedrockChatGenerator(model=model_name) - response = client.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": {"type": "any"}}) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" @@ -366,102 +165,32 @@ def test_tools_use(self, model_name): assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "top_song" in first_reply.content.lower(), "First reply does not contain top_song" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" - fc_response = json.loads(first_reply.content) - assert "name" in fc_response, "First reply does not contain name of the tool" - assert "input" in fc_response, "First reply does not contain input of the tool" - - -class TestMistralAdapter: - def test_prepare_body_with_default_params(self) -> None: - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = { - "max_tokens": 512, - "prompt": "[INST] Hello, how are you? [/INST]", - } - body = layer.prepare_body([ChatMessage.from_user(prompt)]) + if first_reply.meta and "usage" in first_reply.meta: + assert "prompt_tokens" in first_reply.meta["usage"] + assert "completion_tokens" in first_reply.meta["usage"] - assert body == expected_body + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + @pytest.mark.integration + def test_default_inference_with_streaming(self, model_name, chat_messages): + streaming_callback_called = False + paris_found_in_response = False - def test_prepare_body_with_custom_inference_params(self) -> None: - layer = MistralChatAdapter(truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) - prompt = "Hello, how are you?" - expected_body = { - "prompt": "[INST] Hello, how are you? [/INST]", - "max_tokens": 512, - "temperature": 0.7, - "top_p": 0.8, - } + def streaming_callback(chunk: StreamingChunk): + nonlocal streaming_callback_called, paris_found_in_response + streaming_callback_called = True + assert isinstance(chunk, StreamingChunk) + assert chunk.content is not None + if not paris_found_in_response: + paris_found_in_response = "paris" in chunk.content.lower() - body = layer.prepare_body([ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69) - - assert body == expected_body - - def test_mistral_chat_template_correct_order(self): - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_assistant("B"), ChatMessage.from_user("C")]) - layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_user("B"), ChatMessage.from_assistant("C")]) - - def test_mistral_chat_template_incorrect_order(self): - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - try: - layer.prepare_body([ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - try: - layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_user("B")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - try: - layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_system("B")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None: - monkeypatch.delenv("HF_TOKEN", raising=False) - with ( - patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, - patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), - caplog.at_level(logging.WARNING), - ): - MistralChatAdapter(truncate=True, generation_kwargs={}) - mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf") - assert "no HF_TOKEN was found" in caplog.text - - def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None: - monkeypatch.setenv("HF_TOKEN", "test") - with ( - patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, - patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), - ): - MistralChatAdapter(truncate=True, generation_kwargs={}) - mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1") - - @pytest.mark.skipif( - not os.environ.get("HF_API_TOKEN", None), - reason=( - "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also " - "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`" - ), - ) - @pytest.mark.parametrize("model_name", MISTRAL_MODELS) - @pytest.mark.integration - def test_default_inference_params(self, model_name, chat_messages): - client = AmazonBedrockChatGenerator(model=model_name) + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback) response = client.run(chat_messages) - assert "replies" in response, "Response does not contain 'replies' key" + assert streaming_callback_called, "Streaming callback was not called" + assert paris_found_in_response, "The streaming callback response did not contain 'paris'" replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" @@ -473,77 +202,44 @@ def test_default_inference_params(self, model_name, chat_messages): assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" - -@pytest.fixture -def chat_messages(): - messages = [ - ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), - ChatMessage.from_user("What's the capital of France?"), - ] - return messages - - -class TestMetaLlama2ChatAdapter: - @pytest.mark.integration - def test_prepare_body_with_default_params(self) -> None: - # leave this test as integration because we really need only tokenizer from HF - # that way we can ensure prompt chat message formatting - layer = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512} - - body = layer.prepare_body([ChatMessage.from_user(prompt)]) - - assert body == expected_body - + @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) @pytest.mark.integration - def test_prepare_body_with_custom_inference_params(self) -> None: - # leave this test as integration because we really need only tokenizer from HF - # that way we can ensure prompt chat message formatting - layer = MetaLlama2ChatAdapter( - truncate=True, - generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]}, - ) - prompt = "Hello, how are you?" - - # expected body is different because stop_sequences and top_k are not supported by MetaLlama2 - expected_body = { - "prompt": "[INST] Hello, how are you? [/INST]", - "max_gen_len": 69, - "temperature": 0.7, - "top_p": 0.8, + def test_tools_use(self, model_name): + """ + Test function calling with AWS Bedrock Anthropic adapter + """ + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = { + "tools": [ + { + "toolSpec": { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station " + "for which you want the most popular song. " + "Example calls signs are WZPZ and WKRP.", + } + }, + "required": ["sign"], + } + }, + } + } + ], + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + "toolChoice": {"auto": {}}, } - body = layer.prepare_body( - [ChatMessage.from_user(prompt)], - temperature=0.7, - top_p=0.8, - top_k=5, - max_gen_len=69, - stop_sequences=["CUSTOM_STOP"], - ) - - assert body == expected_body - - @pytest.mark.integration - def test_get_responses(self) -> None: - adapter = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) - response_body = {"generation": "This is a single response."} - expected_response = "This is a single response." - response_message = adapter.get_responses(response_body) - # assert that the type of each item in the list is a ChatMessage - for message in response_message: - assert isinstance(message, ChatMessage) - - assert response_message == [ChatMessage.from_assistant(expected_response)] - - @pytest.mark.parametrize("model_name", MODELS_TO_TEST) - @pytest.mark.integration - def test_default_inference_params(self, model_name, chat_messages): + messages = [] + messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) client = AmazonBedrockChatGenerator(model=model_name) - response = client.run(chat_messages) - - assert "replies" in response, "Response does not contain 'replies' key" + response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" @@ -552,32 +248,70 @@ def test_default_inference_params(self, model_name, chat_messages): assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" - if first_reply.meta and "usage" in first_reply.meta: - assert "prompt_tokens" in first_reply.meta["usage"] - assert "completion_tokens" in first_reply.meta["usage"] - - @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + # Some models return thinking message as first and the second one as the tool call + if len(replies) > 1: + second_reply = replies[1] + assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" + assert second_reply.content, "Second reply has no content" + assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" + tool_call = json.loads(second_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" + else: + # case where the model returns the tool call as the first message + # double check that the tool call is correct + tool_call = json.loads(first_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" + + @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) @pytest.mark.integration - def test_default_inference_with_streaming(self, model_name, chat_messages): - streaming_callback_called = False - paris_found_in_response = False - - def streaming_callback(chunk: StreamingChunk): - nonlocal streaming_callback_called, paris_found_in_response - streaming_callback_called = True - assert isinstance(chunk, StreamingChunk) - assert chunk.content is not None - if not paris_found_in_response: - paris_found_in_response = "paris" in chunk.content.lower() - - client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback) - response = client.run(chat_messages) + def test_tools_use_with_streaming(self, model_name): + """ + Test function calling with AWS Bedrock Anthropic adapter + """ + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = { + "tools": [ + { + "toolSpec": { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station " + "for which you want the most popular song. Example " + "calls signs are WZPZ and WKRP.", + } + }, + "required": ["sign"], + } + }, + } + } + ], + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + "toolChoice": {"auto": {}}, + } - assert streaming_callback_called, "Streaming callback was not called" - assert paris_found_in_response, "The streaming callback response did not contain 'paris'" + messages = [] + messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=print_streaming_chunk) + response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) replies = response["replies"] assert isinstance(replies, list), "Replies is not a list" assert len(replies) > 0, "No replies received" @@ -586,5 +320,139 @@ def streaming_callback(chunk: StreamingChunk): assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" assert first_reply.content, "First reply has no content" assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" + + # Some models return thinking message as first and the second one as the tool call + if len(replies) > 1: + second_reply = replies[1] + assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" + assert second_reply.content, "Second reply has no content" + assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" + tool_call = json.loads(second_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" + else: + # case where the model returns the tool call as the first message + # double check that the tool call is correct + tool_call = json.loads(first_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" + + def test_extract_replies_from_response(self, mock_boto3_session): + """ + Test that extract_replies_from_response correctly processes both text and tool use responses + """ + generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + + # Test case 1: Simple text response + text_response = { + "output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}}, + "stopReason": "complete", + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, + } + + replies = generator.extract_replies_from_response(text_response) + assert len(replies) == 1 + assert replies[0].content == "This is a test response" + assert replies[0].role == ChatRole.ASSISTANT + assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["finish_reason"] == "complete" + assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + + # Test case 2: Tool use response + tool_response = { + "output": { + "message": { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}], + } + }, + "stopReason": "tool_call", + "usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40}, + } + + replies = generator.extract_replies_from_response(tool_response) + assert len(replies) == 1 + tool_content = json.loads(replies[0].content) + assert tool_content["toolUseId"] == "123" + assert tool_content["name"] == "test_tool" + assert tool_content["input"] == {"key": "value"} + assert replies[0].meta["finish_reason"] == "tool_call" + assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40} + + # Test case 3: Mixed content response + mixed_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + {"text": "Let me help you with that. I'll use the search tool to find the answer."}, + {"toolUse": {"toolUseId": "456", "name": "search_tool", "input": {"query": "test"}}}, + ], + } + }, + "stopReason": "complete", + "usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60}, + } + + replies = generator.extract_replies_from_response(mixed_response) + assert len(replies) == 2 + assert replies[0].content == "Let me help you with that. I'll use the search tool to find the answer." + tool_content = json.loads(replies[1].content) + assert tool_content["toolUseId"] == "456" + assert tool_content["name"] == "search_tool" + assert tool_content["input"] == {"query": "test"} + + def test_process_streaming_response(self, mock_boto3_session): + """ + Test that process_streaming_response correctly handles streaming events and accumulates responses + """ + generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + + streaming_chunks = [] + + def test_callback(chunk: StreamingChunk): + streaming_chunks.append(chunk) + + # Simulate a stream of events for both text and tool use + events = [ + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "Let me "}}}, + {"contentBlockDelta": {"delta": {"text": "help you."}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "search_tool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"query":'}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '"test"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "complete"}}, + {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}}, + ] + + replies = generator.process_streaming_response(events, test_callback) + + # Verify streaming chunks were received for text content + assert len(streaming_chunks) == 2 + assert streaming_chunks[0].content == "Let me " + assert streaming_chunks[1].content == "help you." + + # Verify final replies + assert len(replies) == 2 + # Check text reply + assert replies[0].content == "Let me help you." + assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["finish_reason"] == "complete" + assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + + # Check tool use reply + tool_content = json.loads(replies[1].content) + assert tool_content["toolUseId"] == "123" + assert tool_content["name"] == "search_tool" + assert tool_content["input"] == {"query": "test"}