From 1c9b710083b296085184f5b167d4198c6fd665c0 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 21 Nov 2024 16:11:58 +0100 Subject: [PATCH 01/17] Initial commit --- .../amazon_bedrock/chat/adapters.py | 569 ------------------ .../amazon_bedrock/chat/chat_generator.py | 231 +++---- .../tests/test_chat_generator.py | 424 +------------ 3 files changed, 138 insertions(+), 1086 deletions(-) delete mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py deleted file mode 100644 index cbb5ee370..000000000 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ /dev/null @@ -1,569 +0,0 @@ -import json -import logging -import os -from abc import ABC, abstractmethod -from typing import Any, Callable, ClassVar, Dict, List, Optional - -from botocore.eventstream import EventStream -from haystack.components.generators.openai_utils import _convert_message_to_openai_format -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk -from transformers import AutoTokenizer, PreTrainedTokenizer - -from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler - -logger = logging.getLogger(__name__) - - -class BedrockModelChatAdapter(ABC): - """ - Base class for Amazon Bedrock chat model adapters. - - Each subclass of this class is designed to address the unique specificities of a particular chat LLM it adapts, - focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted chat LLMs. - """ - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: - """ - Initializes the chat adapter with the truncate parameter and generation kwargs. - """ - self.generation_kwargs = generation_kwargs - self.truncate = truncate - - @abstractmethod - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Amazon Bedrock request. - Subclasses should override this method to package the chat messages into the request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - - def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the responses from the Amazon Bedrock response. - - :param response_body: The response body. - :returns: The extracted responses. 
- """ - return self._extract_messages_from_response(response_body) - - def get_stream_responses( - self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] - ) -> List[ChatMessage]: - streaming_chunks: List[StreamingChunk] = [] - last_decoded_chunk: Dict[str, Any] = {} - for event in stream: - chunk = event.get("chunk") - if chunk: - last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - streaming_chunk = self._build_streaming_chunk(last_decoded_chunk) - streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk - streaming_chunks.append(streaming_chunk) - responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()] - return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] - - @staticmethod - def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None: - """ - Updates target_dict with values from updates_dict. Merges lists instead of overriding them. - - :param target_dict: The dictionary to update. - :param updates_dict: The dictionary with updates. - :param allowed_params: The list of allowed params to use. - """ - for key, value in updates_dict.items(): - if key not in allowed_params: - logger.warning(f"Parameter '{key}' is not allowed and will be ignored.") - continue - if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): - # Merge lists and remove duplicates - target_dict[key] = sorted(set(target_dict[key] + value)) - else: - # Override the value in target_dict - target_dict[key] = value - - def _get_params( - self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str] - ) -> Dict[str, Any]: - """ - Merges params from inference_kwargs with the default params and self.generation_kwargs. - Uses a helper function to merge lists or override values as necessary. - - :param inference_kwargs: The inference kwargs to merge. - :param default_params: The default params to start with. - :param allowed_params: The list of allowed params to use. - :returns: The merged params. - """ - # Start with a copy of default_params - kwargs = default_params.copy() - - # Update the default params with self.generation_kwargs and finally inference_kwargs - self._update_params(kwargs, self.generation_kwargs, allowed_params) - self._update_params(kwargs, inference_kwargs, allowed_params) - - return kwargs - - def _ensure_token_limit(self, prompt: str) -> str: - """ - Ensures that the prompt is within the token limit for the model. - :param prompt: The prompt to check. - :returns: The resized prompt. - """ - resize_info = self.check_prompt(prompt) - if resize_info["prompt_length"] != resize_info["new_prompt_length"]: - logger.warning( - "The prompt was truncated from %s tokens to %s tokens so that the prompt length and " - "the answer length (%s tokens) fit within the model's max token limit (%s tokens). " - "Shorten the prompt or it will be cut off.", - resize_info["prompt_length"], - max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore - resize_info["max_length"], - resize_info["model_max_length"], - ) - return str(resize_info["resized_prompt"]) - - @abstractmethod - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. 
- :returns: A dictionary containing the resized prompt and additional information. - """ - - @abstractmethod - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :returns: The extracted ChatMessage list. - """ - - @abstractmethod - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - - -class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Anthropic Claude chat model. - """ - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - ALLOWED_PARAMS: ClassVar[List[str]] = [ - "anthropic_version", - "max_tokens", - "stop_sequences", - "temperature", - "top_p", - "top_k", - "system", - "tools", - "tool_choice", - ] - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): - """ - Initializes the Anthropic Claude chat adapter. - - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Anthropic Claude has a limit of at least 100000 tokens - # https://docs.anthropic.com/claude/reference/input-and-output-sizes - model_max_length = self.generation_kwargs.pop("model_max_length", 100000) - - # Truncate prompt if prompt tokens > model_max_length-max_length - # (max_length is the length of the generated text) - # TODO use Anthropic tokenizer to get the precise prompt length - # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting - self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Anthropic Claude request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - default_params = { - "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required - } - - # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) - if stop_sequences: - inference_kwargs["stop_sequences"] = stop_sequences - # pop stream kwarg from inference_kwargs as Anthropic does not support it (if provided) - inference_kwargs.pop("stream", None) - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {**self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: - """ - Prepares the chat messages for the Anthropic Claude request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a dictionary. 
- """ - body: Dict[str, Any] = {} - system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None - body["messages"] = [ - self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT) - ] - if system: - body["system"] = system - # Ensure token limit for each message in the body - if self.truncate: - for message in body["messages"]: - for content in message["content"]: - content["text"] = self._ensure_token_limit(content["text"]) - return body - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - messages: List[ChatMessage] = [] - if response_body.get("type") == "message": - if response_body.get("stop_reason") == "tool_use": # If `tool_use` we only keep the tool_use content - for content in response_body["content"]: - if content.get("type") == "tool_use": - meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} - json_answer = json.dumps(content) - messages.append(ChatMessage.from_assistant(json_answer, meta=meta)) - else: # For other stop_reason, return all text content - for content in response_body["content"]: - if content.get("type") == "text": - meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} - messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) - - return messages - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": - return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) - return StreamingChunk(content="", meta=chunk) - - def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: - """ - Convert a ChatMessage to a dictionary with the content and role fields. - :param m: The ChatMessage to convert. - :return: The dictionary with the content and role fields. - """ - return {"content": [{"type": "text", "text": m.content}], "role": m.role.value} - - -class MistralChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Mistral chat model. 
- """ - - chat_template = """ - {% if messages[0]['role'] == 'system' %} - {% set loop_messages = messages[1:] %} - {% set system_message = messages[0]['content'] %} - {% else %} - {% set loop_messages = messages %} - {% set system_message = false %} - {% endif %} - {{bos_token}} - {% for message in loop_messages %} - {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if loop.index0 == 0 and system_message != false %} - {% set content = system_message + '\n' + message['content'] %} - {% else %} - {% set content = message['content'] %} - {% endif %} - {% if message['role'] == 'user' %} - {{ '[INST] ' + content.strip() + ' [/INST]' }} - {% elif message['role'] == 'assistant' %} - {{ content.strip() + eos_token }} - {% endif %} - {% endfor %} - """ - chat_template = "".join(line.strip() for line in chat_template.splitlines()) - - # the above template was designed to match https://docs.mistral.ai/models/#chat-template - # and to support system messages, otherwise we could use the default mistral chat template - # available on HF infrastructure - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - ALLOWED_PARAMS: ClassVar[List[str]] = [ - "max_tokens", - "safe_prompt", - "random_seed", - "temperature", - "top_p", - ] - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): - """ - Initializes the Mistral chat adapter. - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Mistral has a limit of at least 32000 tokens - model_max_length = self.generation_kwargs.pop("model_max_length", 32000) - - # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer - # a) we should get good estimates for the prompt length - # b) we can use apply_chat_template with the template above to delineate ChatMessages - # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model. - tokenizer: PreTrainedTokenizer - if os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN"): - tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") - else: - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf") - logger.warning( - "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for " - "estimating the prompt length because no HF_TOKEN was found. Using " - "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var " - "HF_TOKEN containing a Hugging Face token and make sure you have access to the model." - ) - - self.prompt_handler = DefaultPromptHandler( - tokenizer=tokenizer, - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Mistral request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. 
- """ - default_params = { - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required - } - # replace stop_words from inference_kwargs with stop, as this is Mistral specific parameter - stop_words = inference_kwargs.pop("stop_words", []) - if stop_words: - inference_kwargs["stop"] = stop_words - - # pop stream kwarg from inference_kwargs as Mistral does not support it (if provided) - inference_kwargs.pop("stream", None) - - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - """ - Prepares the chat messages for the Mistral request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string. - """ - # it would be great to use the default mistral chat template, but it doesn't support system messages - # the class variable defined chat_template is a workaround to support system messages - # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json - # but we'll use our custom chat template - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=[_convert_message_to_openai_format(m) for m in messages], - tokenize=False, - chat_template=self.chat_template, - ) - if self.truncate: - prepared_prompt = self._ensure_token_limit(prepared_prompt) - return prepared_prompt - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - messages: List[ChatMessage] = [] - responses = response_body.get("outputs", []) - for response in responses: - meta = {k: v for k, v in response.items() if k not in ["text"]} - messages.append(ChatMessage.from_assistant(response["text"], meta=meta)) - return messages - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - response_chunk = chunk.get("outputs", []) - if response_chunk: - return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk) - return StreamingChunk(content="", meta=chunk) - - -class MetaLlama2ChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Meta Llama 2 models. 
- """ - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html - ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"] - - chat_template = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'] %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = false %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 and system_message != false %}" - "{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ ' ' + content.strip() + ' ' + eos_token }}" - "{% endif %}" - "{% endfor %}" - ) - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: - """ - Initializes the Meta Llama 2 chat adapter. - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Llama 2 has context window size of 4096 tokens - # with some exceptions when the context window has been extended - model_max_length = self.generation_kwargs.pop("model_max_length", 4096) - - # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2 - # a) we should get good estimates for the prompt length (empirically close to llama 2) - # b) we can use apply_chat_template with the template above to delineate ChatMessages - tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base") - tokenizer.bos_token = "" - tokenizer.eos_token = "" - tokenizer.unk_token = "" - self.prompt_handler = DefaultPromptHandler( - tokenizer=tokenizer, - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_gen_len") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Meta Llama 2 request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - """ - default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} - - # no support for stop words in Meta Llama 2 - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - """ - Prepares the chat messages for the Meta Llama 2 request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string ready for the model. 
- """ - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=messages, tokenize=False, chat_template=self.chat_template - ) - - if self.truncate: - prepared_prompt = self._ensure_token_limit(prepared_prompt) - return prepared_prompt - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - message_tag = "generation" - metadata = {k: v for (k, v) in response_body.items() if k != message_tag} - return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 183198bce..bfae0317a 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -1,12 +1,10 @@ -import json import logging -import re -from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional from botocore.config import Config from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -16,8 +14,6 @@ ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session -from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter, MistralChatAdapter - logger = logging.getLogger(__name__) @@ -58,12 +54,6 @@ class AmazonBedrockChatGenerator: supports Amazon Bedrock. """ - SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, - r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter, - r"([a-z]{2}\.)?mistral.*": MistralChatAdapter, - } - def __init__( self, model: str, @@ -132,14 +122,6 @@ def __init__( self.truncate = truncate self.boto3_config = boto3_config - # get the model adapter for the given model - model_adapter_cls = self.get_model_adapter(model=model) - if not model_adapter_cls: - msg = f"AmazonBedrockGenerator doesn't support the model {model}." 
- raise AmazonBedrockConfigurationError(msg) - self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {}) - - # create the AWS session and client def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -162,89 +144,9 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception - @component.output_types(replies=List[ChatMessage]) - def run( - self, - messages: List[ChatMessage], - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - ): - """ - Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. - - :param messages: The messages to generate a response to. - :param streaming_callback: - A callback function that is called when a new token is received from the stream. - :param generation_kwargs: Additional generation keyword arguments passed to the model. - :returns: A dictionary with the following keys: - - `replies`: The generated List of `ChatMessage` objects. - """ - generation_kwargs = generation_kwargs or {} - generation_kwargs = generation_kwargs.copy() - - streaming_callback = streaming_callback or self.streaming_callback - generation_kwargs["stream"] = streaming_callback is not None - - # check if the prompt is a list of ChatMessage objects - if not ( - isinstance(messages, list) - and len(messages) > 0 - and all(isinstance(message, ChatMessage) for message in messages) - ): - msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." - raise ValueError(msg) - - body = self.model_adapter.prepare_body( - messages=messages, **{"stop_words": self.stop_words, **generation_kwargs} - ) - try: - if streaming_callback: - response = self.client.invoke_model_with_response_stream( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" - ) - response_stream = response["body"] - replies = self.model_adapter.get_stream_responses( - stream=response_stream, streaming_callback=streaming_callback - ) - else: - response = self.client.invoke_model( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" - ) - response_body = json.loads(response.get("body").read().decode("utf-8")) - replies = self.model_adapter.get_responses(response_body=response_body) - except ClientError as exception: - msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" - raise AmazonBedrockInferenceError(msg) from exception - - # rename the meta key to be inline with OpenAI meta output keys - for response in replies: - if response.meta: - if "usage" in response.meta: - if "input_tokens" in response.meta["usage"]: - response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens") - if "output_tokens" in response.meta["usage"]: - response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens") - else: - response.meta["usage"] = {} - if "prompt_token_count" in response.meta: - response.meta["usage"]["prompt_tokens"] = response.meta.pop("prompt_token_count") - if "generation_token_count" in response.meta: - response.meta["usage"]["completion_tokens"] = response.meta.pop("generation_token_count") - - return {"replies": replies} - - @classmethod - def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]: - """ - Returns the model adapter for the given model. 
- - :param model: The model to get the adapter for. - :returns: The model adapter for the given model, or None if the model is not supported. - """ - for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): - if re.fullmatch(pattern, model): - return adapter - return None + self.generation_kwargs = generation_kwargs or {} + self.stop_words = stop_words or [] + self.streaming_callback = streaming_callback def to_dict(self) -> Dict[str, Any]: """ @@ -263,7 +165,7 @@ def to_dict(self) -> Dict[str, Any]: aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, stop_words=self.stop_words, - generation_kwargs=self.model_adapter.generation_kwargs, + generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, truncate=self.truncate, boto3_config=self.boto3_config, @@ -274,10 +176,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": """ Deserializes the component from a dictionary. - :param data: - Dictionary to deserialize from. + :param data: Dictionary with serialized data. :returns: - Deserialized component. + Instance of `AmazonBedrockChatGenerator`. """ init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") @@ -288,3 +189,119 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], ) return default_from_dict(cls, data) + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + generation_kwargs = generation_kwargs or {} + inference_config = self.generation_kwargs.copy() + inference_config.update(generation_kwargs) + + # Prepare system prompts if any + system_prompts: List[Dict[str, Any]] = [] # Type annotation added + if messages and messages[0].is_from(ChatRole.SYSTEM): + system_message = messages[0] + system_prompts = [{"text": system_message.content}] + messages = messages[1:] + + # Prepare messages + messages_list = [] + for msg in messages: + message_dict = {"role": msg.role.value, "content": [{"text": msg.content}]} + messages_list.append(message_dict) + + try: + if streaming_callback or self.streaming_callback: + response = self.client.converse_stream( + modelId=self.model, + messages=messages_list, + system=system_prompts, # Now properly typed + inferenceConfig=inference_config, + ) + response_stream = response.get("stream") + if not response_stream: + msg = "No stream found in the response." 
+ raise AmazonBedrockInferenceError(msg) + callback = streaming_callback or self.streaming_callback + if callback is None: # This should never happen due to the if condition above + msg = "No streaming callback provided" + raise ValueError(msg) + replies = self.process_streaming_response(response_stream, callback) + else: + response = self.client.converse( + modelId=self.model, + messages=messages_list, + system=system_prompts, # Now properly typed + inferenceConfig=inference_config, + ) + replies = self.extract_replies_from_response(response) + except ClientError as exception: + msg = f"Could not generate inference for Amazon Bedrock model {self.model} due: {exception}" + raise AmazonBedrockInferenceError(msg) from exception + + return {"replies": replies} + + def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + replies = [] + if "output" in response_body and "message" in response_body["output"]: + message = response_body["output"]["message"] + if message["role"] == "assistant": + content_blocks = message["content"] + text = "" + for content_block in content_blocks: + if "text" in content_block: + text += content_block["text"] + + # Convert usage format from Bedrock to OpenAI format + usage = response_body.get("usage", {}) + meta = { + "model": self.model, + "index": 0, + "finish_reason": response_body.get("stopReason"), + "usage": { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + }, + } + + replies.append(ChatMessage.from_assistant(content=text, meta=meta)) + return replies + + def process_streaming_response( + self, response_stream, streaming_callback: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: + content = "" + meta = { + "model": self.model, + "index": 0, + } + + for event in response_stream: + # if "messageStart" in event: + # role = event["messageStart"]["role"] + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + delta_text = delta.get("text", "") + if delta_text: + content += delta_text + streaming_chunk = StreamingChunk(content=delta_text, meta=None) + streaming_callback(streaming_chunk) + if "messageStop" in event: + meta["finish_reason"] = event["messageStop"].get("stopReason") + if "metadata" in event: + metadata = event["metadata"] + if "usage" in metadata: + usage = metadata["usage"] + meta["usage"] = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + + replies = [ChatMessage.from_assistant(content=content, meta=meta)] + return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 8d6a5c3ee..c349d035e 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,23 +1,11 @@ -import json -import logging -import os -from typing import Any, Dict, Optional, Type -from unittest.mock import MagicMock, patch - import pytest from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator -from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( - AnthropicClaudeChatAdapter, - BedrockModelChatAdapter, - MetaLlama2ChatAdapter, - 
MistralChatAdapter, -) KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"] +MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0"] MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"] MISTRAL_MODELS = [ "mistral.mistral-7b-instruct-v0:2", @@ -26,16 +14,16 @@ ] -@pytest.mark.parametrize( - "boto3_config", - [ - None, - { - "read_timeout": 1000, - }, - ], -) -def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]): +@pytest.fixture +def chat_messages(): + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + return messages + + +def test_to_dict(mock_boto3_session): """ Test that the to_dict method returns the correct dictionary without aws credentials """ @@ -96,7 +84,6 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any } ) assert generator.model == "anthropic.claude-v2" - assert generator.model_adapter.generation_kwargs == {"temperature": 0.7} assert generator.streaming_callback == print_streaming_chunk assert generator.boto3_config == boto3_config @@ -112,8 +99,6 @@ def test_default_constructor(mock_boto3_session, set_env_variables): assert layer.model == "anthropic.claude-v2" assert layer.truncate is True - assert layer.model_adapter.prompt_handler is not None - assert layer.model_adapter.prompt_handler.model_max_length == 100000 # assert mocked boto3 client called exactly once mock_boto3_session.assert_called_once() @@ -135,9 +120,7 @@ def test_constructor_with_generation_kwargs(mock_boto3_session): generation_kwargs = {"temperature": 0.7} layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs) - assert "temperature" in layer.model_adapter.generation_kwargs - assert layer.model_adapter.generation_kwargs["temperature"] == 0.7 - assert layer.model_adapter.truncate is True + assert layer.generation_kwargs == generation_kwargs def test_constructor_with_truncate(mock_boto3_session): @@ -145,7 +128,7 @@ def test_constructor_with_truncate(mock_boto3_session): Test that truncate param is correctly set in the model constructor """ layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", truncate=False) - assert layer.model_adapter.truncate is False + assert layer.truncate is False def test_constructor_with_empty_model(): @@ -156,386 +139,7 @@ def test_constructor_with_empty_model(): AmazonBedrockChatGenerator(model="") -def test_short_prompt_is_not_truncated(mock_boto3_session): - """ - Test that a short prompt is not truncated - """ - # Define a short mock prompt and its tokenized version - mock_prompt_text = "I am a tokenized prompt" - mock_prompt_tokens = mock_prompt_text.split() - - # Mock the tokenizer so it returns our predefined tokens - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize.return_value = mock_prompt_tokens - - # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens - # Since our mock prompt is 5 tokens long, it doesn't exceed the - # total limit (5 prompt tokens + 3 generated tokens < 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): - layer = AmazonBedrockChatGenerator( - "anthropic.claude-v2", - 
generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - prompt_after_resize = layer.model_adapter._ensure_token_limit(mock_prompt_text) - - # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it - assert prompt_after_resize == mock_prompt_text - - -def test_long_prompt_is_truncated(mock_boto3_session): - """ - Test that a long prompt is truncated - """ - # Define a long mock prompt and its tokenized version - long_prompt_text = "I am a tokenized prompt of length eight" - long_prompt_tokens = long_prompt_text.split() - - # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit - truncated_prompt_text = "I am a tokenized prompt of length" - - # Mock the tokenizer to return our predefined tokens - # convert tokens to our predefined truncated text - mock_tokenizer = MagicMock() - mock_tokenizer.tokenize.return_value = long_prompt_tokens - mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text - - # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens - # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): - layer = AmazonBedrockChatGenerator( - "anthropic.claude-v2", - generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - prompt_after_resize = layer.model_adapter._ensure_token_limit(long_prompt_text) - - # The prompt exceeds the limit, _ensure_token_limit truncates it - assert prompt_after_resize == truncated_prompt_text - - -def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): - """ - Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False - """ - messages = [ChatMessage.from_user("What is the biggest city in United States?")] - - # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) - max_length_generated_text = 3 - total_model_max_length = 10 - - with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()): - generator = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", - truncate=False, - generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text}, - ) - - # Mock the _ensure_token_limit method to track if it is called - with patch.object( - generator.model_adapter, "_ensure_token_limit", wraps=generator.model_adapter._ensure_token_limit - ) as mock_ensure_token_limit: - # Mock the model adapter to avoid actual invocation - generator.model_adapter.prepare_body = MagicMock(return_value={}) - generator.client = MagicMock() - generator.client.invoke_model = MagicMock( - return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))} - ) - - generator.model_adapter.get_responses = MagicMock( - return_value=[ - ChatMessage.from_assistant( - content="Some text", - meta={ - "model": "claude-3-sonnet-20240229", - "index": 0, - "finish_reason": "end_turn", - "usage": {"prompt_tokens": 16, "completion_tokens": 55}, - }, - ) - ] - ) - # Invoke the generator - generator.run(messages=messages) - - # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called() - - # Check the prompt 
passed to prepare_body - generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False) - - -@pytest.mark.parametrize( - "model, expected_model_adapter", - [ - ("anthropic.claude-v1", AnthropicClaudeChatAdapter), - ("anthropic.claude-v2", AnthropicClaudeChatAdapter), - ("eu.anthropic.claude-v1", AnthropicClaudeChatAdapter), # cross-region inference - ("us.anthropic.claude-v2", AnthropicClaudeChatAdapter), # cross-region inference - ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter), - ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial - ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), - ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial - ("us.meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference - ("eu.meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference - ("de.meta.llama2-130b-v5", MetaLlama2ChatAdapter), # cross-region inference - ("unknown_model", None), - ], -) -def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelChatAdapter]]): - """ - Test that the correct model adapter is returned for a given model - """ - model_adapter = AmazonBedrockChatGenerator.get_model_adapter(model=model) - assert model_adapter == expected_model_adapter - - -class TestAnthropicClaudeAdapter: - def test_prepare_body_with_default_params(self) -> None: - layer = AnthropicClaudeChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": 512, - "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], - } - - body = layer.prepare_body([ChatMessage.from_user(prompt)]) - - assert body == expected_body - - def test_prepare_body_with_custom_inference_params(self) -> None: - layer = AnthropicClaudeChatAdapter( - truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4} - ) - prompt = "Hello, how are you?" - expected_body = { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": 512, - "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], - "stop_sequences": ["CUSTOM_STOP"], - "temperature": 0.7, - "top_k": 5, - "top_p": 0.8, - } - - body = layer.prepare_body( - [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"] - ) - - assert body == expected_body - - @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) - @pytest.mark.integration - def test_tools_use(self, model_name): - """ - Test function calling with AWS Bedrock Anthropic adapter - """ - # See https://docs.anthropic.com/en/docs/tool-use for more information - tools = [ - { - "name": "top_song", - "description": "Get the most popular song played on a radio station.", - "input_schema": { - "type": "object", - "properties": { - "sign": { - "type": "string", - "description": "The call sign for the radio station for which you want the most popular" - " song. 
Example calls signs are WZPZ and WKRP.", - } - }, - "required": ["sign"], - }, - } - ] - messages = [] - messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) - client = AmazonBedrockChatGenerator(model=model_name) - response = client.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": {"type": "any"}}) - replies = response["replies"] - assert isinstance(replies, list), "Replies is not a list" - assert len(replies) > 0, "No replies received" - - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "top_song" in first_reply.content.lower(), "First reply does not contain top_song" - assert first_reply.meta, "First reply has no metadata" - fc_response = json.loads(first_reply.content) - assert "name" in fc_response, "First reply does not contain name of the tool" - assert "input" in fc_response, "First reply does not contain input of the tool" - - -class TestMistralAdapter: - def test_prepare_body_with_default_params(self) -> None: - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = { - "max_tokens": 512, - "prompt": "[INST] Hello, how are you? [/INST]", - } - - body = layer.prepare_body([ChatMessage.from_user(prompt)]) - - assert body == expected_body - - def test_prepare_body_with_custom_inference_params(self) -> None: - layer = MistralChatAdapter(truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) - prompt = "Hello, how are you?" - expected_body = { - "prompt": "[INST] Hello, how are you? 
[/INST]", - "max_tokens": 512, - "temperature": 0.7, - "top_p": 0.8, - } - - body = layer.prepare_body([ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69) - - assert body == expected_body - - def test_mistral_chat_template_correct_order(self): - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_assistant("B"), ChatMessage.from_user("C")]) - layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_user("B"), ChatMessage.from_assistant("C")]) - - def test_mistral_chat_template_incorrect_order(self): - layer = MistralChatAdapter(truncate=True, generation_kwargs={}) - try: - layer.prepare_body([ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - try: - layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_user("B")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - try: - layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_system("B")]) - msg = "Expected TemplateError" - raise AssertionError(msg) - except Exception as e: - assert "Conversation roles must alternate user/assistant/" in str(e) - - def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None: - monkeypatch.delenv("HF_TOKEN", raising=False) - with ( - patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, - patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), - caplog.at_level(logging.WARNING), - ): - MistralChatAdapter(truncate=True, generation_kwargs={}) - mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf") - assert "no HF_TOKEN was found" in caplog.text - - def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None: - monkeypatch.setenv("HF_TOKEN", "test") - with ( - patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained, - patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"), - ): - MistralChatAdapter(truncate=True, generation_kwargs={}) - mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1") - - @pytest.mark.skipif( - not os.environ.get("HF_API_TOKEN", None), - reason=( - "To run this test, you need to set the HF_API_TOKEN environment variable. 
The associated account must also " - "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`" - ), - ) - @pytest.mark.parametrize("model_name", MISTRAL_MODELS) - @pytest.mark.integration - def test_default_inference_params(self, model_name, chat_messages): - client = AmazonBedrockChatGenerator(model=model_name) - response = client.run(chat_messages) - - assert "replies" in response, "Response does not contain 'replies' key" - replies = response["replies"] - assert isinstance(replies, list), "Replies is not a list" - assert len(replies) > 0, "No replies received" - - first_reply = replies[0] - assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" - assert first_reply.content, "First reply has no content" - assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" - assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" - assert first_reply.meta, "First reply has no metadata" - - -@pytest.fixture -def chat_messages(): - messages = [ - ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), - ChatMessage.from_user("What's the capital of France?"), - ] - return messages - - -class TestMetaLlama2ChatAdapter: - @pytest.mark.integration - def test_prepare_body_with_default_params(self) -> None: - # leave this test as integration because we really need only tokenizer from HF - # that way we can ensure prompt chat message formatting - layer = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) - prompt = "Hello, how are you?" - expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512} - - body = layer.prepare_body([ChatMessage.from_user(prompt)]) - - assert body == expected_body - - @pytest.mark.integration - def test_prepare_body_with_custom_inference_params(self) -> None: - # leave this test as integration because we really need only tokenizer from HF - # that way we can ensure prompt chat message formatting - layer = MetaLlama2ChatAdapter( - truncate=True, - generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]}, - ) - prompt = "Hello, how are you?" - - # expected body is different because stop_sequences and top_k are not supported by MetaLlama2 - expected_body = { - "prompt": "[INST] Hello, how are you? [/INST]", - "max_gen_len": 69, - "temperature": 0.7, - "top_p": 0.8, - } - - body = layer.prepare_body( - [ChatMessage.from_user(prompt)], - temperature=0.7, - top_p=0.8, - top_k=5, - max_gen_len=69, - stop_sequences=["CUSTOM_STOP"], - ) - - assert body == expected_body - - @pytest.mark.integration - def test_get_responses(self) -> None: - adapter = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={}) - response_body = {"generation": "This is a single response."} - expected_response = "This is a single response." 
- response_message = adapter.get_responses(response_body) - # assert that the type of each item in the list is a ChatMessage - for message in response_message: - assert isinstance(message, ChatMessage) - - assert response_message == [ChatMessage.from_assistant(expected_response)] +class TestAmazonBedrockChatGeneratorInference: @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration From 35da4070d353fed0847b7175588af847b5e8fcc5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 21 Nov 2024 17:01:23 +0100 Subject: [PATCH 02/17] Update models tested --- integrations/amazon_bedrock/tests/test_chat_generator.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index c349d035e..c9eaa8df7 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -5,13 +5,8 @@ from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0"] -MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"] -MISTRAL_MODELS = [ - "mistral.mistral-7b-instruct-v0:2", - "mistral.mixtral-8x7b-instruct-v0:1", - "mistral.mistral-large-2402-v1:0", -] +MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] +MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0"] @pytest.fixture From a0a1350c089498fc90d55473871ad9cfea59efc4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 25 Nov 2024 15:22:02 +0100 Subject: [PATCH 03/17] Add tool support --- .../amazon_bedrock/chat/chat_generator.py | 114 +++++++++++------- .../tests/test_chat_generator.py | 73 ++++++++++- 2 files changed, 141 insertions(+), 46 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index bfae0317a..0313b145c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -1,3 +1,4 @@ +import json import logging from typing import Any, Callable, Dict, List, Optional @@ -20,10 +21,10 @@ @component class AmazonBedrockChatGenerator: """ - Completes chats using LLMs hosted on Amazon Bedrock. + Completes chats using LLMs hosted on Amazon Bedrock available via the Bedrock Converse API. For example, to use the Anthropic Claude 3 Sonnet model, initialize this component with the - 'anthropic.claude-3-sonnet-20240229-v1:0' model name. + 'anthropic.claude-3-5-sonnet-20240620-v1:0' model name. 
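A minimal usage sketch for the Converse-based component, assuming the parameter routing implemented in the updated `run` method later in this patch: camelCase Converse keys (`maxTokens`, `stopSequences`, `temperature`, `topP`) are forwarded as `inferenceConfig`, a `toolConfig` dict is passed through to the Converse API, and any remaining keys land in `additionalModelRequestFields`. Import paths and the model ID are taken from elsewhere in this diff; treat the exact behavior as an assumption, not a guarantee.

```python
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# Model ID as used in the tests below; any Converse-capable Bedrock model should work.
generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0")

# "maxTokens" and "temperature" are expected to be routed to inferenceConfig; a
# "toolConfig" entry would be forwarded to the Converse API unchanged, and any
# other keys would be sent via additionalModelRequestFields.
result = generator.run(
    messages=[ChatMessage.from_user("What's Natural Language Processing?")],
    generation_kwargs={"maxTokens": 256, "temperature": 0.3},
)
print(result["replies"][0].content)
```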
### Usage example @@ -36,7 +37,7 @@ class AmazonBedrockChatGenerator: ChatMessage.from_user("What's Natural Language Processing?")] - client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", + client = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", streaming_callback=print_streaming_chunk) client.run(messages, generation_kwargs={"max_tokens": 512}) @@ -198,46 +199,59 @@ def run( generation_kwargs: Optional[Dict[str, Any]] = None, ): generation_kwargs = generation_kwargs or {} - inference_config = self.generation_kwargs.copy() - inference_config.update(generation_kwargs) - # Prepare system prompts if any - system_prompts: List[Dict[str, Any]] = [] # Type annotation added + # Merge generation_kwargs with defaults + merged_kwargs = self.generation_kwargs.copy() + merged_kwargs.update(generation_kwargs) + + # Extract known inference parameters + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html + inference_config = { + key: merged_kwargs.pop(key, None) + for key in ["maxTokens", "stopSequences", "temperature", "topP"] + if key in merged_kwargs + } + + # Extract tool configuration if present + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = merged_kwargs.pop("toolConfig", None) + + # Any remaining kwargs go to additionalModelRequestFields + additional_fields = merged_kwargs if merged_kwargs else None + + # Prepare system prompts and messages + system_prompts = [] if messages and messages[0].is_from(ChatRole.SYSTEM): - system_message = messages[0] - system_prompts = [{"text": system_message.content}] + system_prompts = [{"text": messages[0].content}] messages = messages[1:] - # Prepare messages - messages_list = [] - for msg in messages: - message_dict = {"role": msg.role.value, "content": [{"text": msg.content}]} - messages_list.append(message_dict) + messages_list = [ + {"role": msg.role.value, "content": [{"text": msg.content}]} + for msg in messages + ] try: - if streaming_callback or self.streaming_callback: - response = self.client.converse_stream( - modelId=self.model, - messages=messages_list, - system=system_prompts, # Now properly typed - inferenceConfig=inference_config, - ) + # Build API parameters + params = { + "modelId": self.model, + "messages": messages_list, + "system": system_prompts, + "inferenceConfig": inference_config, + } + if tool_config: + params["toolConfig"] = tool_config + if additional_fields: + params["additionalModelRequestFields"] = additional_fields + callback = streaming_callback or self.streaming_callback + if callback: + response = self.client.converse_stream(**params) response_stream = response.get("stream") if not response_stream: msg = "No stream found in the response." 
raise AmazonBedrockInferenceError(msg) - callback = streaming_callback or self.streaming_callback - if callback is None: # This should never happen due to the if condition above - msg = "No streaming callback provided" - raise ValueError(msg) replies = self.process_streaming_response(response_stream, callback) else: - response = self.client.converse( - modelId=self.model, - messages=messages_list, - system=system_prompts, # Now properly typed - inferenceConfig=inference_config, - ) + response = self.client.converse(**params) replies = self.extract_replies_from_response(response) except ClientError as exception: msg = f"Could not generate inference for Amazon Bedrock model {self.model} due: {exception}" @@ -251,25 +265,36 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C message = response_body["output"]["message"] if message["role"] == "assistant": content_blocks = message["content"] - text = "" - for content_block in content_blocks: - if "text" in content_block: - text += content_block["text"] - # Convert usage format from Bedrock to OpenAI format - usage = response_body.get("usage", {}) - meta = { + # Common meta information + base_meta = { "model": self.model, "index": 0, "finish_reason": response_body.get("stopReason"), "usage": { - "prompt_tokens": usage.get("inputTokens", 0), - "completion_tokens": usage.get("outputTokens", 0), - "total_tokens": usage.get("totalTokens", 0), - }, + # OpenAI's format for usage for cross ChatGenerator compatibility + "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), + "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), + "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), + } } - replies.append(ChatMessage.from_assistant(content=text, meta=meta)) + # Process each content block separately + for content_block in content_blocks: + if "text" in content_block: + replies.append( + ChatMessage.from_assistant( + content=content_block["text"], + meta=base_meta.copy() + ) + ) + elif "toolUse" in content_block: + replies.append( + ChatMessage.from_assistant( + content=json.dumps(content_block["toolUse"]), + meta={**base_meta.copy()} + ) + ) return replies def process_streaming_response( @@ -282,8 +307,6 @@ def process_streaming_response( } for event in response_stream: - # if "messageStart" in event: - # role = event["messageStart"]["role"] if "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] delta_text = delta.get("text", "") @@ -297,6 +320,7 @@ def process_streaming_response( metadata = event["metadata"] if "usage" in metadata: usage = metadata["usage"] + # use OpenAI's format for usage for cross ChatGenerator compatibility meta["usage"] = { "prompt_tokens": usage.get("inputTokens", 0), "completion_tokens": usage.get("outputTokens", 0), diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index c9eaa8df7..bc8c84e4e 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,3 +1,5 @@ +import json + import pytest from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk @@ -6,7 +8,7 @@ KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", 
"mistral.mistral-large-2402-v1:0"] -MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0"] +MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0","cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] @pytest.fixture @@ -187,3 +189,72 @@ def streaming_callback(chunk: StreamingChunk): assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" assert first_reply.meta, "First reply has no metadata" + + @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS) + @pytest.mark.integration + def test_tools_use(self, model_name): + """ + Test function calling with AWS Bedrock Anthropic adapter + """ + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = { + "tools": [ + { + "toolSpec": { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP." + } + }, + "required": [ + "sign" + ] + } + } + } + } + ], + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + "toolChoice": {"auto": {}} + } + + messages = [] + messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + client = AmazonBedrockChatGenerator(model=model_name) + response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert first_reply.meta, "First reply has no metadata" + + + # Some models return thinking message as first and the second one as the tool call + if len(replies) > 1: + second_reply = replies[1] + assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" + assert second_reply.content, "Second reply has no content" + assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" + tool_call = json.loads(second_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + else: + # case where the model returns the tool call as the first message + # double check that the tool call is correct + tool_call = json.loads(first_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 
'input' value" From d9babb94d131e573094cdaffd5281ffbf1faf2f5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 26 Nov 2024 13:32:02 +0100 Subject: [PATCH 04/17] Update Amazon Bedrock model names in tests --- .../amazon_bedrock/tests/test_chat_generator.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index bc8c84e4e..132e0123f 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -25,7 +25,7 @@ def test_to_dict(mock_boto3_session): Test that the to_dict method returns the correct dictionary without aws credentials """ generator = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", + model="cohere.command-r-plus-v1:0", generation_kwargs={"temperature": 0.7}, streaming_callback=print_streaming_chunk, boto3_config=boto3_config, @@ -38,7 +38,7 @@ def test_to_dict(mock_boto3_session): "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, - "model": "anthropic.claude-v2", + "model": "cohere.command-r-plus-v1:0", "generation_kwargs": {"temperature": 0.7}, "stop_words": [], "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", @@ -72,7 +72,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, - "model": "anthropic.claude-v2", + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "truncate": True, @@ -80,7 +80,7 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any }, } ) - assert generator.model == "anthropic.claude-v2" + assert generator.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert generator.streaming_callback == print_streaming_chunk assert generator.boto3_config == boto3_config @@ -91,10 +91,10 @@ def test_default_constructor(mock_boto3_session, set_env_variables): """ layer = AmazonBedrockChatGenerator( - model="anthropic.claude-v2", + model="anthropic.claude-3-5-sonnet-20240620-v1:0", ) - assert layer.model == "anthropic.claude-v2" + assert layer.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert layer.truncate is True # assert mocked boto3 client called exactly once @@ -116,7 +116,7 @@ def test_constructor_with_generation_kwargs(mock_boto3_session): """ generation_kwargs = {"temperature": 0.7} - layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs) + layer = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs) assert layer.generation_kwargs == generation_kwargs @@ -124,7 +124,7 @@ def test_constructor_with_truncate(mock_boto3_session): """ Test that truncate param is correctly set in the model constructor """ - layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", truncate=False) + layer = 
AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", truncate=False) assert layer.truncate is False From 2511db3b5c8e2885403acb505c61136f036c0fa6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 26 Nov 2024 14:27:15 +0100 Subject: [PATCH 05/17] Support for tool streaming --- .../amazon_bedrock/chat/chat_generator.py | 80 ++++++++++++++----- .../tests/test_chat_generator.py | 73 ++++++++++++++++- 2 files changed, 134 insertions(+), 19 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 0313b145c..d50a76fc2 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -300,32 +300,76 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C def process_streaming_response( self, response_stream, streaming_callback: Callable[[StreamingChunk], None] ) -> List[ChatMessage]: - content = "" - meta = { + replies = [] + current_content = "" + current_tool_use = None + base_meta = { "model": self.model, "index": 0, } for event in response_stream: - if "contentBlockDelta" in event: + if "contentBlockStart" in event: + # Reset accumulators for new message + current_content = "" + current_tool_use = None + block_start = event["contentBlockStart"] + if "start" in block_start and "toolUse" in block_start["start"]: + tool_start = block_start["start"]["toolUse"] + current_tool_use = { + "toolUseId": tool_start["toolUseId"], + "name": tool_start["name"], + "input": "" # Will accumulate deltas as string + } + + elif "contentBlockDelta" in event: delta = event["contentBlockDelta"]["delta"] - delta_text = delta.get("text", "") - if delta_text: - content += delta_text + if "text" in delta: + delta_text = delta["text"] + current_content += delta_text streaming_chunk = StreamingChunk(content=delta_text, meta=None) + # it only makes sense to call callback on text deltas streaming_callback(streaming_chunk) - if "messageStop" in event: - meta["finish_reason"] = event["messageStop"].get("stopReason") - if "metadata" in event: + elif "toolUse" in delta and current_tool_use: + # Accumulate tool use input deltas + current_tool_use["input"] += delta["toolUse"].get("input", "") + elif "contentBlockStop" in event: + if current_tool_use: + # Parse accumulated input if it's a JSON string + try: + input_json = json.loads(current_tool_use["input"]) + current_tool_use["input"] = input_json + except json.JSONDecodeError: + # Keep as string if not valid JSON + pass + + tool_content = json.dumps(current_tool_use) + replies.append( + ChatMessage.from_assistant( + content=tool_content, + meta=base_meta.copy() + ) + ) + elif current_content: + replies.append( + ChatMessage.from_assistant(content=current_content, meta=base_meta.copy()) + ) + + elif "messageStop" in event: + # not 100% correct for multiple messages but no way around it + for reply in replies: + reply.meta["finish_reason"] = event["messageStop"].get("stopReason") + + elif "metadata" in event: metadata = event["metadata"] - if "usage" in metadata: - usage = metadata["usage"] - # use OpenAI's format for usage for cross ChatGenerator compatibility - meta["usage"] = { - "prompt_tokens": usage.get("inputTokens", 0), - 
"completion_tokens": usage.get("outputTokens", 0), - "total_tokens": usage.get("totalTokens", 0), - } + # not 100% correct for multiple messages but no way around it + for reply in replies: + if "usage" in metadata: + usage = metadata["usage"] + reply.meta["usage"] = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } - replies = [ChatMessage.from_assistant(content=content, meta=meta)] return replies diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 132e0123f..ee26b02f8 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -8,7 +8,10 @@ KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] -MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0","cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] +MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] + +# so far we've discovered these models support streaming and tool use +STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"] @pytest.fixture @@ -239,6 +242,74 @@ def test_tools_use(self, model_name): assert first_reply.meta, "First reply has no metadata" + # Some models return thinking message as first and the second one as the tool call + if len(replies) > 1: + second_reply = replies[1] + assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance" + assert second_reply.content, "Second reply has no content" + assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant" + tool_call = json.loads(second_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + else: + # case where the model returns the tool call as the first message + # double check that the tool call is correct + tool_call = json.loads(first_reply.content) + assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" + assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" + assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" + assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + + @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) + @pytest.mark.integration + def test_tools_use_with_streaming(self, model_name): + """ + Test function calling with AWS Bedrock Anthropic adapter + """ + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html + tool_config = { + "tools": [ + { + "toolSpec": { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "inputSchema": { + "json": { + "type": "object", + 
"properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP." + } + }, + "required": [ + "sign" + ] + } + } + } + } + ], + # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + "toolChoice": {"auto": {}} + } + + messages = [] + messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?")) + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=print_streaming_chunk) + response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config}) + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert first_reply.meta, "First reply has no metadata" + # Some models return thinking message as first and the second one as the tool call if len(replies) > 1: second_reply = replies[1] From 88bf808f2077cf75c26cb45b1b8e930e00efd13b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 26 Nov 2024 15:06:49 +0100 Subject: [PATCH 06/17] Format --- .../amazon_bedrock/chat/chat_generator.py | 30 +++------- .../tests/test_chat_generator.py | 57 ++++++++++++------- 2 files changed, 44 insertions(+), 43 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index d50a76fc2..96653bafd 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -225,10 +225,7 @@ def run( system_prompts = [{"text": messages[0].content}] messages = messages[1:] - messages_list = [ - {"role": msg.role.value, "content": [{"text": msg.content}]} - for msg in messages - ] + messages_list = [{"role": msg.role.value, "content": [{"text": msg.content}]} for msg in messages] try: # Build API parameters @@ -276,23 +273,17 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0), "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0), "total_tokens": response_body.get("usage", {}).get("totalTokens", 0), - } + }, } # Process each content block separately for content_block in content_blocks: if "text" in content_block: - replies.append( - ChatMessage.from_assistant( - content=content_block["text"], - meta=base_meta.copy() - ) - ) + replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy())) elif "toolUse" in content_block: replies.append( ChatMessage.from_assistant( - content=json.dumps(content_block["toolUse"]), - meta={**base_meta.copy()} + content=json.dumps(content_block["toolUse"]), meta=base_meta.copy() ) ) return replies @@ -319,7 +310,7 @@ def process_streaming_response( current_tool_use = { "toolUseId": tool_start["toolUseId"], "name": tool_start["name"], - "input": "" # Will accumulate deltas as string + "input": 
"", # Will accumulate deltas as string } elif "contentBlockDelta" in event: @@ -344,16 +335,9 @@ def process_streaming_response( pass tool_content = json.dumps(current_tool_use) - replies.append( - ChatMessage.from_assistant( - content=tool_content, - meta=base_meta.copy() - ) - ) + replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy())) elif current_content: - replies.append( - ChatMessage.from_assistant(content=current_content, meta=base_meta.copy()) - ) + replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy())) elif "messageStop" in event: # not 100% correct for multiple messages but no way around it diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index ee26b02f8..6913ddc9a 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -7,8 +7,16 @@ from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -MODELS_TO_TEST = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] -MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0", "mistral.mistral-large-2402-v1:0"] +MODELS_TO_TEST = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "cohere.command-r-plus-v1:0", + "mistral.mistral-large-2402-v1:0", +] +MODELS_TO_TEST_WITH_TOOLS = [ + "anthropic.claude-3-5-sonnet-20240620-v1:0", + "cohere.command-r-plus-v1:0", + "mistral.mistral-large-2402-v1:0", +] # so far we've discovered these models support streaming and tool use STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"] @@ -119,7 +127,9 @@ def test_constructor_with_generation_kwargs(mock_boto3_session): """ generation_kwargs = {"temperature": 0.7} - layer = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs) + layer = AmazonBedrockChatGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs + ) assert layer.generation_kwargs == generation_kwargs @@ -212,19 +222,19 @@ def test_tools_use(self, model_name): "properties": { "sign": { "type": "string", - "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP." + "description": "The call sign for the radio station " + "for which you want the most popular song. 
" + "Example calls signs are WZPZ and WKRP.", } }, - "required": [ - "sign" - ] + "required": ["sign"], } - } + }, } } ], # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - "toolChoice": {"auto": {}} + "toolChoice": {"auto": {}}, } messages = [] @@ -241,7 +251,6 @@ def test_tools_use(self, model_name): assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" assert first_reply.meta, "First reply has no metadata" - # Some models return thinking message as first and the second one as the tool call if len(replies) > 1: second_reply = replies[1] @@ -252,7 +261,9 @@ def test_tools_use(self, model_name): assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" else: # case where the model returns the tool call as the first message # double check that the tool call is correct @@ -260,7 +271,9 @@ def test_tools_use(self, model_name): assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS) @pytest.mark.integration @@ -281,19 +294,19 @@ def test_tools_use_with_streaming(self, model_name): "properties": { "sign": { "type": "string", - "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP." + "description": "The call sign for the radio station " + "for which you want the most popular song. 
Example " + "calls signs are WZPZ and WKRP.", } }, - "required": [ - "sign" - ] + "required": ["sign"], } - } + }, } } ], # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - "toolChoice": {"auto": {}} + "toolChoice": {"auto": {}}, } messages = [] @@ -320,7 +333,9 @@ def test_tools_use_with_streaming(self, model_name): assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" else: # case where the model returns the tool call as the first message # double check that the tool call is correct @@ -328,4 +343,6 @@ def test_tools_use_with_streaming(self, model_name): assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key" assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value" assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key" - assert tool_call["input"]["sign"] == "WZPZ", f"Tool call {tool_call} does not contain the correct 'input' value" + assert ( + tool_call["input"]["sign"] == "WZPZ" + ), f"Tool call {tool_call} does not contain the correct 'input' value" From bdf6c34d26c7fb9d81420e330dff4dcddf2556c6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 5 Dec 2024 15:23:32 +0100 Subject: [PATCH 07/17] Minot test updates --- .../amazon_bedrock/tests/test_chat_generator.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 6913ddc9a..ed2b80250 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -1,4 +1,5 @@ import json +from typing import Any, Dict, Optional import pytest from haystack.components.generators.utils import print_streaming_chunk @@ -30,8 +31,16 @@ def chat_messages(): ] return messages - -def test_to_dict(mock_boto3_session): +@pytest.mark.parametrize( + "boto3_config", + [ + None, + { + "read_timeout": 1000, + }, + ], +) +def test_to_dict(mock_boto3_session, boto3_config): """ Test that the to_dict method returns the correct dictionary without aws credentials """ From 6338e746eef18ac3c4ebf9411eee4aeb11420db5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 5 Dec 2024 15:29:39 +0100 Subject: [PATCH 08/17] Lint --- integrations/amazon_bedrock/tests/test_chat_generator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index ed2b80250..345902b42 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -31,6 +31,7 @@ def chat_messages(): ] return messages + @pytest.mark.parametrize( "boto3_config", [ From fe17cfb857f941cad21848832f6bf1183cfaa804 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 14:40:58 +0100 Subject: [PATCH 09/17] Remove truncate init parameter --- .../generators/amazon_bedrock/chat/chat_generator.py | 10 
+++------- .../amazon_bedrock/tests/test_chat_generator.py | 11 ----------- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 96653bafd..fc701901b 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -67,8 +67,7 @@ def __init__( aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - truncate: Optional[bool] = True, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, boto3_config: Optional[Dict[str, Any]] = None, ): """ @@ -102,7 +101,6 @@ def __init__( function that handles the streaming chunks. The callback function receives a [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches the streaming mode on. - :param truncate: Whether to truncate the prompt messages or not. :param boto3_config: The configuration for the boto3 client. :raises ValueError: If the model name is empty or None. @@ -119,8 +117,7 @@ def __init__( self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name self.stop_words = stop_words or [] - self.streaming_callback = streaming_callback - self.truncate = truncate + self.streaming_callback = streaming_callback self.boto3_config = boto3_config def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -167,8 +164,7 @@ def to_dict(self) -> Dict[str, Any]: model=self.model, stop_words=self.stop_words, generation_kwargs=self.generation_kwargs, - streaming_callback=callback_name, - truncate=self.truncate, + streaming_callback=callback_name, boto3_config=self.boto3_config, ) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 345902b42..c006b2299 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -63,7 +63,6 @@ def test_to_dict(mock_boto3_session, boto3_config): "generation_kwargs": {"temperature": 0.7}, "stop_words": [], "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "truncate": True, "boto3_config": boto3_config, }, } @@ -96,7 +95,6 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "generation_kwargs": {"temperature": 0.7}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", - "truncate": True, "boto3_config": boto3_config, }, } @@ -116,7 +114,6 @@ def test_default_constructor(mock_boto3_session, set_env_variables): ) assert layer.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" - assert layer.truncate is True # assert mocked boto3 client called exactly once mock_boto3_session.assert_called_once() @@ -143,14 +140,6 @@ def test_constructor_with_generation_kwargs(mock_boto3_session): assert layer.generation_kwargs == generation_kwargs -def test_constructor_with_truncate(mock_boto3_session): - """ - Test that 
truncate param is correctly set in the model constructor - """ - layer = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0", truncate=False) - assert layer.truncate is False - - def test_constructor_with_empty_model(): """ Test that the constructor raises an error when the model is empty From 39787b975b97ffedc7ffbcd6da2cf89bc6f7c374 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 14:42:44 +0100 Subject: [PATCH 10/17] Pull try down --- .../amazon_bedrock/chat/chat_generator.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index fc701901b..0329685fb 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -223,19 +223,21 @@ def run( messages_list = [{"role": msg.role.value, "content": [{"text": msg.content}]} for msg in messages] + # Build API parameters + params = { + "modelId": self.model, + "messages": messages_list, + "system": system_prompts, + "inferenceConfig": inference_config, + } + if tool_config: + params["toolConfig"] = tool_config + if additional_fields: + params["additionalModelRequestFields"] = additional_fields + + callback = streaming_callback or self.streaming_callback + try: - # Build API parameters - params = { - "modelId": self.model, - "messages": messages_list, - "system": system_prompts, - "inferenceConfig": inference_config, - } - if tool_config: - params["toolConfig"] = tool_config - if additional_fields: - params["additionalModelRequestFields"] = additional_fields - callback = streaming_callback or self.streaming_callback if callback: response = self.client.converse_stream(**params) response_stream = response.get("stream") From 2adb50b0979412751025d3980fb188a880440d78 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 14:46:09 +0100 Subject: [PATCH 11/17] Add extract_replies_from_response unit test --- .../tests/test_chat_generator.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index c006b2299..aa840d86f 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -345,3 +345,105 @@ def test_tools_use_with_streaming(self, model_name): assert ( tool_call["input"]["sign"] == "WZPZ" ), f"Tool call {tool_call} does not contain the correct 'input' value" + + def test_extract_replies_from_response(self, mock_boto3_session): + """ + Test that extract_replies_from_response correctly processes both text and tool use responses + """ + generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + + # Test case 1: Simple text response + text_response = { + "output": { + "message": { + "role": "assistant", + "content": [{"text": "This is a test response"}] + } + }, + "stopReason": "complete", + "usage": { + "inputTokens": 10, + "outputTokens": 20, + "totalTokens": 30 + } + } + + replies = generator.extract_replies_from_response(text_response) + assert len(replies) == 1 + assert replies[0].content == 
"This is a test response" + assert replies[0].role == ChatRole.ASSISTANT + assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["finish_reason"] == "complete" + assert replies[0].meta["usage"] == { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30 + } + + # Test case 2: Tool use response + tool_response = { + "output": { + "message": { + "role": "assistant", + "content": [{ + "toolUse": { + "toolUseId": "123", + "name": "test_tool", + "input": {"key": "value"} + } + }] + } + }, + "stopReason": "tool_call", + "usage": { + "inputTokens": 15, + "outputTokens": 25, + "totalTokens": 40 + } + } + + replies = generator.extract_replies_from_response(tool_response) + assert len(replies) == 1 + tool_content = json.loads(replies[0].content) + assert tool_content["toolUseId"] == "123" + assert tool_content["name"] == "test_tool" + assert tool_content["input"] == {"key": "value"} + assert replies[0].meta["finish_reason"] == "tool_call" + assert replies[0].meta["usage"] == { + "prompt_tokens": 15, + "completion_tokens": 25, + "total_tokens": 40 + } + + # Test case 3: Mixed content response + mixed_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + {"text": "Let me help you with that. I'll use the search tool to find the answer."}, + { + "toolUse": { + "toolUseId": "456", + "name": "search_tool", + "input": {"query": "test"} + } + } + ] + } + }, + "stopReason": "complete", + "usage": { + "inputTokens": 25, + "outputTokens": 35, + "totalTokens": 60 + } + } + + replies = generator.extract_replies_from_response(mixed_response) + assert len(replies) == 2 + assert replies[0].content == "Let me help you with that." + tool_content = json.loads(replies[1].content) + assert tool_content["toolUseId"] == "456" + assert tool_content["name"] == "search_tool" + assert tool_content["input"] == {"query": "test"} From c7f51a1ee9bf3b30549ea758e25b428b0b8f97e6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 14:53:20 +0100 Subject: [PATCH 12/17] Add process_streaming_response unit test --- .../tests/test_chat_generator.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index aa840d86f..ab980603b 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -447,3 +447,52 @@ def test_extract_replies_from_response(self, mock_boto3_session): assert tool_content["toolUseId"] == "456" assert tool_content["name"] == "search_tool" assert tool_content["input"] == {"query": "test"} + + def test_process_streaming_response(self, mock_boto3_session): + """ + Test that process_streaming_response correctly handles streaming events and accumulates responses + """ + generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") + + streaming_chunks = [] + def test_callback(chunk: StreamingChunk): + streaming_chunks.append(chunk) + + # Simulate a stream of events for both text and tool use + events = [ + {"contentBlockStart": {"start": {"text": ""}}}, + {"contentBlockDelta": {"delta": {"text": "Let me "}}}, + {"contentBlockDelta": {"delta": {"text": "help you."}}}, + {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "search_tool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"query":'}}}}, + 
{"contentBlockDelta": {"delta": {"toolUse": {"input": '"test"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "complete"}}, + {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}} + ] + + replies = generator.process_streaming_response(events, test_callback) + + # Verify streaming chunks were received for text content + assert len(streaming_chunks) == 2 + assert streaming_chunks[0].content == "Let me " + assert streaming_chunks[1].content == "help you." + + # Verify final replies + assert len(replies) == 2 + # Check text reply + assert replies[0].content == "Let me help you." + assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert replies[0].meta["finish_reason"] == "complete" + assert replies[0].meta["usage"] == { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30 + } + + # Check tool use reply + tool_content = json.loads(replies[1].content) + assert tool_content["toolUseId"] == "123" + assert tool_content["name"] == "search_tool" + assert tool_content["input"] == {"query": "test"} From aa7c5c06fa59b166aa14fbb52bca74a7bdf082c1 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 14:56:21 +0100 Subject: [PATCH 13/17] Lint --- .../amazon_bedrock/chat/chat_generator.py | 6 +- .../tests/test_chat_generator.py | 64 ++++--------------- 2 files changed, 15 insertions(+), 55 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 0329685fb..1153e581e 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -67,7 +67,7 @@ def __init__( aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 generation_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, boto3_config: Optional[Dict[str, Any]] = None, ): """ @@ -117,7 +117,7 @@ def __init__( self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name self.stop_words = stop_words or [] - self.streaming_callback = streaming_callback + self.streaming_callback = streaming_callback self.boto3_config = boto3_config def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -164,7 +164,7 @@ def to_dict(self) -> Dict[str, Any]: model=self.model, stop_words=self.stop_words, generation_kwargs=self.generation_kwargs, - streaming_callback=callback_name, + streaming_callback=callback_name, boto3_config=self.boto3_config, ) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index ab980603b..5ccc3083e 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -354,18 +354,9 @@ def test_extract_replies_from_response(self, mock_boto3_session): # Test case 1: Simple text response text_response = { - "output": { - "message": { - "role": "assistant", - "content": [{"text": "This is a test response"}] - } - }, + "output": {"message": {"role": "assistant", 
"content": [{"text": "This is a test response"}]}}, "stopReason": "complete", - "usage": { - "inputTokens": 10, - "outputTokens": 20, - "totalTokens": 30 - } + "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}, } replies = generator.extract_replies_from_response(text_response) @@ -374,32 +365,18 @@ def test_extract_replies_from_response(self, mock_boto3_session): assert replies[0].role == ChatRole.ASSISTANT assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert replies[0].meta["finish_reason"] == "complete" - assert replies[0].meta["usage"] == { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30 - } + assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} # Test case 2: Tool use response tool_response = { "output": { "message": { "role": "assistant", - "content": [{ - "toolUse": { - "toolUseId": "123", - "name": "test_tool", - "input": {"key": "value"} - } - }] + "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}], } }, "stopReason": "tool_call", - "usage": { - "inputTokens": 15, - "outputTokens": 25, - "totalTokens": 40 - } + "usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40}, } replies = generator.extract_replies_from_response(tool_response) @@ -409,11 +386,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): assert tool_content["name"] == "test_tool" assert tool_content["input"] == {"key": "value"} assert replies[0].meta["finish_reason"] == "tool_call" - assert replies[0].meta["usage"] == { - "prompt_tokens": 15, - "completion_tokens": 25, - "total_tokens": 40 - } + assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40} # Test case 3: Mixed content response mixed_response = { @@ -422,22 +395,12 @@ def test_extract_replies_from_response(self, mock_boto3_session): "role": "assistant", "content": [ {"text": "Let me help you with that. I'll use the search tool to find the answer."}, - { - "toolUse": { - "toolUseId": "456", - "name": "search_tool", - "input": {"query": "test"} - } - } - ] + {"toolUse": {"toolUseId": "456", "name": "search_tool", "input": {"query": "test"}}}, + ], } }, "stopReason": "complete", - "usage": { - "inputTokens": 25, - "outputTokens": 35, - "totalTokens": 60 - } + "usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60}, } replies = generator.extract_replies_from_response(mixed_response) @@ -455,6 +418,7 @@ def test_process_streaming_response(self, mock_boto3_session): generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0") streaming_chunks = [] + def test_callback(chunk: StreamingChunk): streaming_chunks.append(chunk) @@ -469,7 +433,7 @@ def test_callback(chunk: StreamingChunk): {"contentBlockDelta": {"delta": {"toolUse": {"input": '"test"}'}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "complete"}}, - {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}} + {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}}, ] replies = generator.process_streaming_response(events, test_callback) @@ -485,11 +449,7 @@ def test_callback(chunk: StreamingChunk): assert replies[0].content == "Let me help you." 
assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0" assert replies[0].meta["finish_reason"] == "complete" - assert replies[0].meta["usage"] == { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30 - } + assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} # Check tool use reply tool_content = json.loads(replies[1].content) From 405865b26107677897586fa19152c5f8022bcd3d Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 15:00:41 +0100 Subject: [PATCH 14/17] Small test fix --- integrations/amazon_bedrock/tests/test_chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 5ccc3083e..8eb29729c 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -405,7 +405,7 @@ def test_extract_replies_from_response(self, mock_boto3_session): replies = generator.extract_replies_from_response(mixed_response) assert len(replies) == 2 - assert replies[0].content == "Let me help you with that." + assert replies[0].content == "Let me help you with that. I'll use the search tool to find the answer." tool_content = json.loads(replies[1].content) assert tool_content["toolUseId"] == "456" assert tool_content["name"] == "search_tool" From 3990dfdb8f7cd153f077fffcad42676f3f18f8d4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 15:07:46 +0100 Subject: [PATCH 15/17] Use EventStream from botocore --- .../generators/amazon_bedrock/chat/chat_generator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 1153e581e..2b8d33c61 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional from botocore.config import Config +from botocore.eventstream import EventStream from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk @@ -240,7 +241,7 @@ def run( try: if callback: response = self.client.converse_stream(**params) - response_stream = response.get("stream") + response_stream: EventStream = response.get("stream") if not response_stream: msg = "No stream found in the response." 
raise AmazonBedrockInferenceError(msg) @@ -287,7 +288,7 @@ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[C return replies def process_streaming_response( - self, response_stream, streaming_callback: Callable[[StreamingChunk], None] + self, response_stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] ) -> List[ChatMessage]: replies = [] current_content = "" From 0a318d1fd0cd19c5ff69e48c9e7aabf28a05a8c5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 15:21:44 +0100 Subject: [PATCH 16/17] Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py Co-authored-by: Stefano Fiorucci --- .../components/generators/amazon_bedrock/chat/chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 2b8d33c61..a834fec87 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -219,7 +219,7 @@ def run( # Prepare system prompts and messages system_prompts = [] if messages and messages[0].is_from(ChatRole.SYSTEM): - system_prompts = [{"text": messages[0].content}] + system_prompts = [{"text": messages[0].text}] messages = messages[1:] messages_list = [{"role": msg.role.value, "content": [{"text": msg.content}]} for msg in messages] From a16d3d78fba69e569d274b1fa25f7e86a18a7eb6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Dec 2024 15:21:52 +0100 Subject: [PATCH 17/17] Update integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py Co-authored-by: Stefano Fiorucci --- .../components/generators/amazon_bedrock/chat/chat_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index a834fec87..499fe1c24 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -222,7 +222,7 @@ def run( system_prompts = [{"text": messages[0].text}] messages = messages[1:] - messages_list = [{"role": msg.role.value, "content": [{"text": msg.content}]} for msg in messages] + messages_list = [{"role": msg.role.value, "content": [{"text": msg.text}]} for msg in messages] # Build API parameters params = {
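
Taken together, the patches above replace the per-model adapters with the Bedrock Converse API and route tool definitions through generation_kwargs["toolConfig"]. A minimal usage sketch of the resulting component — not part of the patch series, and assuming valid AWS credentials in the environment plus access to the listed model — could look like this:

from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# Tool definition in the Converse API's ToolConfiguration shape, mirroring the test fixtures above.
tool_config = {
    "tools": [
        {
            "toolSpec": {
                "name": "top_song",
                "description": "Get the most popular song played on a radio station.",
                "inputSchema": {
                    "json": {
                        "type": "object",
                        "properties": {
                            "sign": {"type": "string", "description": "Radio station call sign, e.g. WZPZ."}
                        },
                        "required": ["sign"],
                    }
                },
            }
        }
    ],
    "toolChoice": {"auto": {}},
}

client = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
result = client.run(
    messages=[ChatMessage.from_user("What is the most popular song on WZPZ?")],
    generation_kwargs={"toolConfig": tool_config},
)
for reply in result["replies"]:
    # Text blocks come back as plain content; toolUse blocks come back as JSON-serialized content.
    print(reply.content)

As extract_replies_from_response and the new unit tests show, a toolUse block is returned as a JSON string in the ChatMessage content, so a caller that wants the structured call is expected to json.loads the reply to recover toolUseId, name, and input.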