diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
deleted file mode 100644
index cbb5ee370..000000000
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ /dev/null
@@ -1,569 +0,0 @@
-import json
-import logging
-import os
-from abc import ABC, abstractmethod
-from typing import Any, Callable, ClassVar, Dict, List, Optional
-
-from botocore.eventstream import EventStream
-from haystack.components.generators.openai_utils import _convert_message_to_openai_format
-from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
-from transformers import AutoTokenizer, PreTrainedTokenizer
-
-from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler
-
-logger = logging.getLogger(__name__)
-
-
-class BedrockModelChatAdapter(ABC):
- """
- Base class for Amazon Bedrock chat model adapters.
-
- Each subclass of this class is designed to address the unique specificities of a particular chat LLM it adapts,
- focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted chat LLMs.
- """
-
- def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
- """
- Initializes the chat adapter with the truncate parameter and generation kwargs.
- """
- self.generation_kwargs = generation_kwargs
- self.truncate = truncate
-
- @abstractmethod
- def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
- """
- Prepares the body for the Amazon Bedrock request.
- Subclasses should override this method to package the chat messages into the request.
-
- :param messages: The chat messages to package into the request.
- :param inference_kwargs: Additional inference kwargs to use.
- :returns: The prepared body.
- """
-
- def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
- """
- Extracts the responses from the Amazon Bedrock response.
-
- :param response_body: The response body.
- :returns: The extracted responses.
- """
- return self._extract_messages_from_response(response_body)
-
- def get_stream_responses(
- self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None]
- ) -> List[ChatMessage]:
- streaming_chunks: List[StreamingChunk] = []
- last_decoded_chunk: Dict[str, Any] = {}
- for event in stream:
- chunk = event.get("chunk")
- if chunk:
- last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
- streaming_chunk = self._build_streaming_chunk(last_decoded_chunk)
- streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk
- streaming_chunks.append(streaming_chunk)
- responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()]
- return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses]
-
- @staticmethod
- def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None:
- """
- Updates target_dict with values from updates_dict. Merges lists instead of overriding them.
-
- :param target_dict: The dictionary to update.
- :param updates_dict: The dictionary with updates.
- :param allowed_params: The list of allowed params to use.
- """
- for key, value in updates_dict.items():
- if key not in allowed_params:
- logger.warning(f"Parameter '{key}' is not allowed and will be ignored.")
- continue
- if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list):
- # Merge lists and remove duplicates
- target_dict[key] = sorted(set(target_dict[key] + value))
- else:
- # Override the value in target_dict
- target_dict[key] = value
-
- def _get_params(
- self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str]
- ) -> Dict[str, Any]:
- """
- Merges params from inference_kwargs with the default params and self.generation_kwargs.
- Uses a helper function to merge lists or override values as necessary.
-
- :param inference_kwargs: The inference kwargs to merge.
- :param default_params: The default params to start with.
- :param allowed_params: The list of allowed params to use.
- :returns: The merged params.
- """
- # Start with a copy of default_params
- kwargs = default_params.copy()
-
- # Update the default params with self.generation_kwargs and finally inference_kwargs
- self._update_params(kwargs, self.generation_kwargs, allowed_params)
- self._update_params(kwargs, inference_kwargs, allowed_params)
-
- return kwargs
-
- def _ensure_token_limit(self, prompt: str) -> str:
- """
- Ensures that the prompt is within the token limit for the model.
- :param prompt: The prompt to check.
- :returns: The resized prompt.
- """
- resize_info = self.check_prompt(prompt)
- if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
- logger.warning(
- "The prompt was truncated from %s tokens to %s tokens so that the prompt length and "
- "the answer length (%s tokens) fit within the model's max token limit (%s tokens). "
- "Shorten the prompt or it will be cut off.",
- resize_info["prompt_length"],
- max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore
- resize_info["max_length"],
- resize_info["model_max_length"],
- )
- return str(resize_info["resized_prompt"])
-
- @abstractmethod
- def check_prompt(self, prompt: str) -> Dict[str, Any]:
- """
- Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
-
- :param prompt: The prompt to check.
- :returns: A dictionary containing the resized prompt and additional information.
- """
-
- @abstractmethod
- def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
- """
- Extracts the messages from the response body.
-
- :param response_body: The response body.
- :returns: The extracted ChatMessage list.
- """
-
- @abstractmethod
- def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
- """
- Extracts the content and meta from a streaming chunk.
-
- :param chunk: The streaming chunk as dict.
- :returns: A StreamingChunk object.
- """
-
-
-class AnthropicClaudeChatAdapter(BedrockModelChatAdapter):
- """
- Model adapter for the Anthropic Claude chat model.
- """
-
- # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
- ALLOWED_PARAMS: ClassVar[List[str]] = [
- "anthropic_version",
- "max_tokens",
- "stop_sequences",
- "temperature",
- "top_p",
- "top_k",
- "system",
- "tools",
- "tool_choice",
- ]
-
- def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
- """
- Initializes the Anthropic Claude chat adapter.
-
- :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
- :param generation_kwargs: The generation kwargs.
- """
- super().__init__(truncate, generation_kwargs)
-
- # We pop the model_max_length as it is not sent to the model
- # but used to truncate the prompt if needed
- # Anthropic Claude has a limit of at least 100000 tokens
- # https://docs.anthropic.com/claude/reference/input-and-output-sizes
- model_max_length = self.generation_kwargs.pop("model_max_length", 100000)
-
- # Truncate prompt if prompt tokens > model_max_length-max_length
- # (max_length is the length of the generated text)
- # TODO use Anthropic tokenizer to get the precise prompt length
- # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting
- self.prompt_handler = DefaultPromptHandler(
- tokenizer="gpt2",
- model_max_length=model_max_length,
- max_length=self.generation_kwargs.get("max_tokens") or 512,
- )
-
- def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
- """
- Prepares the body for the Anthropic Claude request.
-
- :param messages: The chat messages to package into the request.
- :param inference_kwargs: Additional inference kwargs to use.
- :returns: The prepared body.
- """
- default_params = {
- "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31",
- "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required
- }
-
- # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it
- stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", [])
- if stop_sequences:
- inference_kwargs["stop_sequences"] = stop_sequences
- # pop stream kwarg from inference_kwargs as Anthropic does not support it (if provided)
- inference_kwargs.pop("stream", None)
- params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
- body = {**self.prepare_chat_messages(messages=messages), **params}
- return body
-
- def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]:
- """
- Prepares the chat messages for the Anthropic Claude request.
-
- :param messages: The chat messages to prepare.
- :returns: The prepared chat messages as a dictionary.
- """
- body: Dict[str, Any] = {}
- system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
- body["messages"] = [
- self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT)
- ]
- if system:
- body["system"] = system
- # Ensure token limit for each message in the body
- if self.truncate:
- for message in body["messages"]:
- for content in message["content"]:
- content["text"] = self._ensure_token_limit(content["text"])
- return body
-
- def check_prompt(self, prompt: str) -> Dict[str, Any]:
- """
- Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
-
- :param prompt: The prompt to check.
- :returns: A dictionary containing the resized prompt and additional information.
- """
- return self.prompt_handler(prompt)
-
- def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
- """
- Extracts the messages from the response body.
-
- :param response_body: The response body.
- :return: The extracted ChatMessage list.
- """
- messages: List[ChatMessage] = []
- if response_body.get("type") == "message":
- if response_body.get("stop_reason") == "tool_use": # If `tool_use` we only keep the tool_use content
- for content in response_body["content"]:
- if content.get("type") == "tool_use":
- meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]}
- json_answer = json.dumps(content)
- messages.append(ChatMessage.from_assistant(json_answer, meta=meta))
- else: # For other stop_reason, return all text content
- for content in response_body["content"]:
- if content.get("type") == "text":
- meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]}
- messages.append(ChatMessage.from_assistant(content["text"], meta=meta))
-
- return messages
-
- def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
- """
- Extracts the content and meta from a streaming chunk.
-
- :param chunk: The streaming chunk as dict.
- :returns: A StreamingChunk object.
- """
- if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
- return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk)
- return StreamingChunk(content="", meta=chunk)
-
- def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
- """
- Convert a ChatMessage to a dictionary with the content and role fields.
- :param m: The ChatMessage to convert.
- :return: The dictionary with the content and role fields.
- """
- return {"content": [{"type": "text", "text": m.content}], "role": m.role.value}
-
-
-class MistralChatAdapter(BedrockModelChatAdapter):
- """
- Model adapter for the Mistral chat model.
- """
-
- chat_template = """
- {% if messages[0]['role'] == 'system' %}
- {% set loop_messages = messages[1:] %}
- {% set system_message = messages[0]['content'] %}
- {% else %}
- {% set loop_messages = messages %}
- {% set system_message = false %}
- {% endif %}
- {{bos_token}}
- {% for message in loop_messages %}
- {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
- {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
- {% endif %}
- {% if loop.index0 == 0 and system_message != false %}
- {% set content = system_message + '\n' + message['content'] %}
- {% else %}
- {% set content = message['content'] %}
- {% endif %}
- {% if message['role'] == 'user' %}
- {{ '[INST] ' + content.strip() + ' [/INST]' }}
- {% elif message['role'] == 'assistant' %}
- {{ content.strip() + eos_token }}
- {% endif %}
- {% endfor %}
- """
- chat_template = "".join(line.strip() for line in chat_template.splitlines())
-
- # the above template was designed to match https://docs.mistral.ai/models/#chat-template
- # and to support system messages, otherwise we could use the default mistral chat template
- # available on HF infrastructure
-
- # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
- ALLOWED_PARAMS: ClassVar[List[str]] = [
- "max_tokens",
- "safe_prompt",
- "random_seed",
- "temperature",
- "top_p",
- ]
-
- def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]):
- """
- Initializes the Mistral chat adapter.
- :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
- :param generation_kwargs: The generation kwargs.
- """
- super().__init__(truncate, generation_kwargs)
-
- # We pop the model_max_length as it is not sent to the model
- # but used to truncate the prompt if needed
- # Mistral has a limit of at least 32000 tokens
- model_max_length = self.generation_kwargs.pop("model_max_length", 32000)
-
- # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer
- # a) we should get good estimates for the prompt length
- # b) we can use apply_chat_template with the template above to delineate ChatMessages
- # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model.
- tokenizer: PreTrainedTokenizer
- if os.environ.get("HF_TOKEN") or os.environ.get("HF_API_TOKEN"):
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
- else:
- tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
- logger.warning(
- "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
- "estimating the prompt length because no HF_TOKEN was found. Using "
- "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
- "HF_TOKEN containing a Hugging Face token and make sure you have access to the model."
- )
-
- self.prompt_handler = DefaultPromptHandler(
- tokenizer=tokenizer,
- model_max_length=model_max_length,
- max_length=self.generation_kwargs.get("max_tokens") or 512,
- )
-
- def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
- """
- Prepares the body for the Mistral request.
-
- :param messages: The chat messages to package into the request.
- :param inference_kwargs: Additional inference kwargs to use.
- :returns: The prepared body.
- """
- default_params = {
- "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required
- }
- # replace stop_words from inference_kwargs with stop, as this is Mistral specific parameter
- stop_words = inference_kwargs.pop("stop_words", [])
- if stop_words:
- inference_kwargs["stop"] = stop_words
-
- # pop stream kwarg from inference_kwargs as Mistral does not support it (if provided)
- inference_kwargs.pop("stream", None)
-
- params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
- body = {"prompt": self.prepare_chat_messages(messages=messages), **params}
- return body
-
- def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
- """
- Prepares the chat messages for the Mistral request.
-
- :param messages: The chat messages to prepare.
- :returns: The prepared chat messages as a string.
- """
- # it would be great to use the default mistral chat template, but it doesn't support system messages
- # the class variable defined chat_template is a workaround to support system messages
- # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
- # but we'll use our custom chat template
- prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
- conversation=[_convert_message_to_openai_format(m) for m in messages],
- tokenize=False,
- chat_template=self.chat_template,
- )
- if self.truncate:
- prepared_prompt = self._ensure_token_limit(prepared_prompt)
- return prepared_prompt
-
- def check_prompt(self, prompt: str) -> Dict[str, Any]:
- """
- Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
-
- :param prompt: The prompt to check.
- :returns: A dictionary containing the resized prompt and additional information.
- """
- return self.prompt_handler(prompt)
-
- def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
- """
- Extracts the messages from the response body.
-
- :param response_body: The response body.
- :return: The extracted ChatMessage list.
- """
- messages: List[ChatMessage] = []
- responses = response_body.get("outputs", [])
- for response in responses:
- meta = {k: v for k, v in response.items() if k not in ["text"]}
- messages.append(ChatMessage.from_assistant(response["text"], meta=meta))
- return messages
-
- def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
- """
- Extracts the content and meta from a streaming chunk.
-
- :param chunk: The streaming chunk as dict.
- :returns: A StreamingChunk object.
- """
- response_chunk = chunk.get("outputs", [])
- if response_chunk:
- return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk)
- return StreamingChunk(content="", meta=chunk)
-
-
-class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
- """
- Model adapter for the Meta Llama 2 models.
- """
-
- # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
- ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"]
-
- chat_template = (
- "{% if messages[0]['role'] == 'system' %}"
- "{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'] %}"
- "{% else %}"
- "{% set loop_messages = messages %}"
- "{% set system_message = false %}"
- "{% endif %}"
- "{% for message in loop_messages %}"
- "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
- "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
- "{% endif %}"
- "{% if loop.index0 == 0 and system_message != false %}"
- "{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}"
- "{% else %}"
- "{% set content = message['content'] %}"
- "{% endif %}"
- "{% if message['role'] == 'user' %}"
- "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ ' ' + content.strip() + ' ' + eos_token }}"
- "{% endif %}"
- "{% endfor %}"
- )
-
- def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
- """
- Initializes the Meta Llama 2 chat adapter.
- :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit.
- :param generation_kwargs: The generation kwargs.
- """
- super().__init__(truncate, generation_kwargs)
- # We pop the model_max_length as it is not sent to the model
- # but used to truncate the prompt if needed
- # Llama 2 has context window size of 4096 tokens
- # with some exceptions when the context window has been extended
- model_max_length = self.generation_kwargs.pop("model_max_length", 4096)
-
- # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2
- # a) we should get good estimates for the prompt length (empirically close to llama 2)
- # b) we can use apply_chat_template with the template above to delineate ChatMessages
- tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
- tokenizer.bos_token = "<s>"
- tokenizer.eos_token = "</s>"
- tokenizer.unk_token = "<unk>"
- self.prompt_handler = DefaultPromptHandler(
- tokenizer=tokenizer,
- model_max_length=model_max_length,
- max_length=self.generation_kwargs.get("max_gen_len") or 512,
- )
-
- def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
- """
- Prepares the body for the Meta Llama 2 request.
-
- :param messages: The chat messages to package into the request.
- :param inference_kwargs: Additional inference kwargs to use.
- """
- default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512}
-
- # no support for stop words in Meta Llama 2
- params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS)
- body = {"prompt": self.prepare_chat_messages(messages=messages), **params}
- return body
-
- def prepare_chat_messages(self, messages: List[ChatMessage]) -> str:
- """
- Prepares the chat messages for the Meta Llama 2 request.
-
- :param messages: The chat messages to prepare.
- :returns: The prepared chat messages as a string ready for the model.
- """
- prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template(
- conversation=messages, tokenize=False, chat_template=self.chat_template
- )
-
- if self.truncate:
- prepared_prompt = self._ensure_token_limit(prepared_prompt)
- return prepared_prompt
-
- def check_prompt(self, prompt: str) -> Dict[str, Any]:
- """
- Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated.
-
- :param prompt: The prompt to check.
- :returns: A dictionary containing the resized prompt and additional information.
-
- """
- return self.prompt_handler(prompt)
-
- def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
- """
- Extracts the messages from the response body.
-
- :param response_body: The response body.
- :return: The extracted ChatMessage list.
- """
- message_tag = "generation"
- metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
- return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
-
- def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk:
- """
- Extracts the content and meta from a streaming chunk.
-
- :param chunk: The streaming chunk as dict.
- :returns: A StreamingChunk object.
- """
- return StreamingChunk(content=chunk.get("generation", ""), meta=chunk)
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
index 183198bce..499fe1c24 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
@@ -1,12 +1,12 @@
import json
import logging
-import re
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional
from botocore.config import Config
+from botocore.eventstream import EventStream
from botocore.exceptions import ClientError
from haystack import component, default_from_dict, default_to_dict
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
@@ -16,18 +16,16 @@
)
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session
-from .adapters import AnthropicClaudeChatAdapter, BedrockModelChatAdapter, MetaLlama2ChatAdapter, MistralChatAdapter
-
logger = logging.getLogger(__name__)
@component
class AmazonBedrockChatGenerator:
"""
- Completes chats using LLMs hosted on Amazon Bedrock.
+ Completes chats using LLMs hosted on Amazon Bedrock that are available through the Bedrock Converse API.
For example, to use the Anthropic Claude 3 Sonnet model, initialize this component with the
- 'anthropic.claude-3-sonnet-20240229-v1:0' model name.
+ 'anthropic.claude-3-5-sonnet-20240620-v1:0' model name.
### Usage example
@@ -40,7 +38,7 @@ class AmazonBedrockChatGenerator:
ChatMessage.from_user("What's Natural Language Processing?")]
- client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0",
+ client = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0",
streaming_callback=print_streaming_chunk)
client.run(messages, generation_kwargs={"max_tokens": 512})
@@ -58,12 +56,6 @@ class AmazonBedrockChatGenerator:
supports Amazon Bedrock.
"""
- SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = {
- r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter,
- r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter,
- r"([a-z]{2}\.)?mistral.*": MistralChatAdapter,
- }
-
def __init__(
self,
model: str,
@@ -77,7 +69,6 @@ def __init__(
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
- truncate: Optional[bool] = True,
boto3_config: Optional[Dict[str, Any]] = None,
):
"""
@@ -111,7 +102,6 @@ def __init__(
function that handles the streaming chunks. The callback function receives a
[StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and
switches the streaming mode on.
- :param truncate: Whether to truncate the prompt messages or not.
:param boto3_config: The configuration for the boto3 client.
:raises ValueError: If the model name is empty or None.
@@ -129,17 +119,8 @@ def __init__(
self.aws_profile_name = aws_profile_name
self.stop_words = stop_words or []
self.streaming_callback = streaming_callback
- self.truncate = truncate
self.boto3_config = boto3_config
- # get the model adapter for the given model
- model_adapter_cls = self.get_model_adapter(model=model)
- if not model_adapter_cls:
- msg = f"AmazonBedrockGenerator doesn't support the model {model}."
- raise AmazonBedrockConfigurationError(msg)
- self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {})
-
- # create the AWS session and client
def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
return secret.resolve_value() if secret else None
@@ -162,89 +143,9 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
)
raise AmazonBedrockConfigurationError(msg) from exception
- @component.output_types(replies=List[ChatMessage])
- def run(
- self,
- messages: List[ChatMessage],
- streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
- generation_kwargs: Optional[Dict[str, Any]] = None,
- ):
- """
- Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM.
-
- :param messages: The messages to generate a response to.
- :param streaming_callback:
- A callback function that is called when a new token is received from the stream.
- :param generation_kwargs: Additional generation keyword arguments passed to the model.
- :returns: A dictionary with the following keys:
- - `replies`: The generated List of `ChatMessage` objects.
- """
- generation_kwargs = generation_kwargs or {}
- generation_kwargs = generation_kwargs.copy()
-
- streaming_callback = streaming_callback or self.streaming_callback
- generation_kwargs["stream"] = streaming_callback is not None
-
- # check if the prompt is a list of ChatMessage objects
- if not (
- isinstance(messages, list)
- and len(messages) > 0
- and all(isinstance(message, ChatMessage) for message in messages)
- ):
- msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
- raise ValueError(msg)
-
- body = self.model_adapter.prepare_body(
- messages=messages, **{"stop_words": self.stop_words, **generation_kwargs}
- )
- try:
- if streaming_callback:
- response = self.client.invoke_model_with_response_stream(
- body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
- )
- response_stream = response["body"]
- replies = self.model_adapter.get_stream_responses(
- stream=response_stream, streaming_callback=streaming_callback
- )
- else:
- response = self.client.invoke_model(
- body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
- )
- response_body = json.loads(response.get("body").read().decode("utf-8"))
- replies = self.model_adapter.get_responses(response_body=response_body)
- except ClientError as exception:
- msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}"
- raise AmazonBedrockInferenceError(msg) from exception
-
- # rename the meta key to be inline with OpenAI meta output keys
- for response in replies:
- if response.meta:
- if "usage" in response.meta:
- if "input_tokens" in response.meta["usage"]:
- response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens")
- if "output_tokens" in response.meta["usage"]:
- response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens")
- else:
- response.meta["usage"] = {}
- if "prompt_token_count" in response.meta:
- response.meta["usage"]["prompt_tokens"] = response.meta.pop("prompt_token_count")
- if "generation_token_count" in response.meta:
- response.meta["usage"]["completion_tokens"] = response.meta.pop("generation_token_count")
-
- return {"replies": replies}
-
- @classmethod
- def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]:
- """
- Returns the model adapter for the given model.
-
- :param model: The model to get the adapter for.
- :returns: The model adapter for the given model, or None if the model is not supported.
- """
- for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
- if re.fullmatch(pattern, model):
- return adapter
- return None
+ self.generation_kwargs = generation_kwargs or {}
+ self.stop_words = stop_words or []
+ self.streaming_callback = streaming_callback
def to_dict(self) -> Dict[str, Any]:
"""
@@ -263,9 +164,8 @@ def to_dict(self) -> Dict[str, Any]:
aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
model=self.model,
stop_words=self.stop_words,
- generation_kwargs=self.model_adapter.generation_kwargs,
+ generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
- truncate=self.truncate,
boto3_config=self.boto3_config,
)
@@ -274,10 +174,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator":
"""
Deserializes the component from a dictionary.
- :param data:
- Dictionary to deserialize from.
+ :param data: Dictionary with serialized data.
:returns:
- Deserialized component.
+ Instance of `AmazonBedrockChatGenerator`.
"""
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
@@ -288,3 +187,172 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator":
["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"],
)
return default_from_dict(cls, data)
+
+ @component.output_types(replies=List[ChatMessage])
+ def run(
+ self,
+ messages: List[ChatMessage],
+ streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+ generation_kwargs: Optional[Dict[str, Any]] = None,
+ ):
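+ """
+ Generates a list of `ChatMessage` responses for the given messages using the Bedrock Converse API.
+
+ :param messages: The messages to generate a response to.
+ :param streaming_callback: A callback function called with each streamed chunk; overrides the callback
+ set in the constructor.
+ :param generation_kwargs: Additional generation keyword arguments, merged with the constructor defaults.
+ :returns: A dictionary with the following keys:
+ - `replies`: The generated list of `ChatMessage` objects.
+ """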
+ generation_kwargs = generation_kwargs or {}
+
+ # Merge generation_kwargs with defaults
+ merged_kwargs = self.generation_kwargs.copy()
+ merged_kwargs.update(generation_kwargs)
+
+ # Extract known inference parameters
+ # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InferenceConfiguration.html
+ inference_config = {
+ key: merged_kwargs.pop(key, None)
+ for key in ["maxTokens", "stopSequences", "temperature", "topP"]
+ if key in merged_kwargs
+ }
+
+ # Extract tool configuration if present
+ # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html
+ tool_config = merged_kwargs.pop("toolConfig", None)
+
+ # Any remaining kwargs go to additionalModelRequestFields
+ additional_fields = merged_kwargs if merged_kwargs else None
+
+ # Prepare system prompts and messages
+ system_prompts = []
+ if messages and messages[0].is_from(ChatRole.SYSTEM):
+ system_prompts = [{"text": messages[0].text}]
+ messages = messages[1:]
+
+ messages_list = [{"role": msg.role.value, "content": [{"text": msg.text}]} for msg in messages]
+
+ # Build API parameters
+ params = {
+ "modelId": self.model,
+ "messages": messages_list,
+ "system": system_prompts,
+ "inferenceConfig": inference_config,
+ }
+ if tool_config:
+ params["toolConfig"] = tool_config
+ if additional_fields:
+ params["additionalModelRequestFields"] = additional_fields
+
+ callback = streaming_callback or self.streaming_callback
+
+ try:
+ if callback:
+ response = self.client.converse_stream(**params)
+ response_stream: EventStream = response.get("stream")
+ if not response_stream:
+ msg = "No stream found in the response."
+ raise AmazonBedrockInferenceError(msg)
+ replies = self.process_streaming_response(response_stream, callback)
+ else:
+ response = self.client.converse(**params)
+ replies = self.extract_replies_from_response(response)
+ except ClientError as exception:
+ msg = f"Could not generate inference for Amazon Bedrock model {self.model} due: {exception}"
+ raise AmazonBedrockInferenceError(msg) from exception
+
+ return {"replies": replies}
+
+ def extract_replies_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
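+ """
+ Extracts the replies from a non-streaming Converse API response.
+
+ :param response_body: The response body returned by `client.converse`.
+ :returns: A list of `ChatMessage` objects, one per text or tool use content block.
+ """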
+ replies = []
+ if "output" in response_body and "message" in response_body["output"]:
+ message = response_body["output"]["message"]
+ if message["role"] == "assistant":
+ content_blocks = message["content"]
+
+ # Common meta information
+ base_meta = {
+ "model": self.model,
+ "index": 0,
+ "finish_reason": response_body.get("stopReason"),
+ "usage": {
+ # OpenAI's format for usage for cross ChatGenerator compatibility
+ "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0),
+ "completion_tokens": response_body.get("usage", {}).get("outputTokens", 0),
+ "total_tokens": response_body.get("usage", {}).get("totalTokens", 0),
+ },
+ }
+
+ # Process each content block separately
+ for content_block in content_blocks:
+ if "text" in content_block:
+ replies.append(ChatMessage.from_assistant(content=content_block["text"], meta=base_meta.copy()))
+ elif "toolUse" in content_block:
+ replies.append(
+ ChatMessage.from_assistant(
+ content=json.dumps(content_block["toolUse"]), meta=base_meta.copy()
+ )
+ )
+ return replies
+
+ def process_streaming_response(
+ self, response_stream: EventStream, streaming_callback: Callable[[StreamingChunk], None]
+ ) -> List[ChatMessage]:
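+ """
+ Accumulates the Converse API event stream into `ChatMessage` replies.
+
+ The streaming callback is invoked for each text delta; tool use blocks are accumulated and
+ serialized to JSON strings.
+
+ :param response_stream: The event stream returned by `client.converse_stream`.
+ :param streaming_callback: A callback function that receives a `StreamingChunk` for each text delta.
+ :returns: The accumulated list of `ChatMessage` objects.
+ """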
+ replies = []
+ current_content = ""
+ current_tool_use = None
+ base_meta = {
+ "model": self.model,
+ "index": 0,
+ }
+
+ for event in response_stream:
+ if "contentBlockStart" in event:
+ # Reset accumulators for new message
+ current_content = ""
+ current_tool_use = None
+ block_start = event["contentBlockStart"]
+ if "start" in block_start and "toolUse" in block_start["start"]:
+ tool_start = block_start["start"]["toolUse"]
+ current_tool_use = {
+ "toolUseId": tool_start["toolUseId"],
+ "name": tool_start["name"],
+ "input": "", # Will accumulate deltas as string
+ }
+
+ elif "contentBlockDelta" in event:
+ delta = event["contentBlockDelta"]["delta"]
+ if "text" in delta:
+ delta_text = delta["text"]
+ current_content += delta_text
+ streaming_chunk = StreamingChunk(content=delta_text, meta=None)
+ # it only makes sense to call callback on text deltas
+ streaming_callback(streaming_chunk)
+ elif "toolUse" in delta and current_tool_use:
+ # Accumulate tool use input deltas
+ current_tool_use["input"] += delta["toolUse"].get("input", "")
+ elif "contentBlockStop" in event:
+ if current_tool_use:
+ # Parse accumulated input if it's a JSON string
+ try:
+ input_json = json.loads(current_tool_use["input"])
+ current_tool_use["input"] = input_json
+ except json.JSONDecodeError:
+ # Keep as string if not valid JSON
+ pass
+
+ tool_content = json.dumps(current_tool_use)
+ replies.append(ChatMessage.from_assistant(content=tool_content, meta=base_meta.copy()))
+ elif current_content:
+ replies.append(ChatMessage.from_assistant(content=current_content, meta=base_meta.copy()))
+
+ elif "messageStop" in event:
+ # the stop reason arrives once per response; apply it to every accumulated reply as an approximation
+ for reply in replies:
+ reply.meta["finish_reason"] = event["messageStop"].get("stopReason")
+
+ elif "metadata" in event:
+ metadata = event["metadata"]
+ # usage metadata arrives once per response; apply it to every accumulated reply as an approximation
+ for reply in replies:
+ if "usage" in metadata:
+ usage = metadata["usage"]
+ reply.meta["usage"] = {
+ "prompt_tokens": usage.get("inputTokens", 0),
+ "completion_tokens": usage.get("outputTokens", 0),
+ "total_tokens": usage.get("totalTokens", 0),
+ }
+
+ return replies
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index 8d6a5c3ee..8eb29729c 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,30 +1,36 @@
import json
-import logging
-import os
-from typing import Any, Dict, Optional, Type
-from unittest.mock import MagicMock, patch
+from typing import Any, Dict, Optional
import pytest
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator
-from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import (
- AnthropicClaudeChatAdapter,
- BedrockModelChatAdapter,
- MetaLlama2ChatAdapter,
- MistralChatAdapter,
-)
KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"
-MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"]
-MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"]
-MISTRAL_MODELS = [
- "mistral.mistral-7b-instruct-v0:2",
- "mistral.mixtral-8x7b-instruct-v0:1",
+MODELS_TO_TEST = [
+ "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ "cohere.command-r-plus-v1:0",
+ "mistral.mistral-large-2402-v1:0",
+]
+MODELS_TO_TEST_WITH_TOOLS = [
+ "anthropic.claude-3-5-sonnet-20240620-v1:0",
+ "cohere.command-r-plus-v1:0",
"mistral.mistral-large-2402-v1:0",
]
+# models verified so far to support both streaming and tool use
+STREAMING_TOOL_MODELS = ["anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-r-plus-v1:0"]
+
+
+@pytest.fixture
+def chat_messages():
+ messages = [
+ ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."),
+ ChatMessage.from_user("What's the capital of France?"),
+ ]
+ return messages
+
@pytest.mark.parametrize(
"boto3_config",
@@ -35,12 +41,12 @@
},
],
)
-def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]):
+def test_to_dict(mock_boto3_session, boto3_config):
"""
Test that the to_dict method returns the correct dictionary without aws credentials
"""
generator = AmazonBedrockChatGenerator(
- model="anthropic.claude-v2",
+ model="cohere.command-r-plus-v1:0",
generation_kwargs={"temperature": 0.7},
streaming_callback=print_streaming_chunk,
boto3_config=boto3_config,
@@ -53,11 +59,10 @@ def test_to_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any]]
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
- "model": "anthropic.claude-v2",
+ "model": "cohere.command-r-plus-v1:0",
"generation_kwargs": {"temperature": 0.7},
"stop_words": [],
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
- "truncate": True,
"boto3_config": boto3_config,
},
}
@@ -87,16 +92,14 @@ def test_from_dict(mock_boto3_session: Any, boto3_config: Optional[Dict[str, Any
"aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False},
"aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False},
"aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False},
- "model": "anthropic.claude-v2",
+ "model": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"generation_kwargs": {"temperature": 0.7},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
- "truncate": True,
"boto3_config": boto3_config,
},
}
)
- assert generator.model == "anthropic.claude-v2"
- assert generator.model_adapter.generation_kwargs == {"temperature": 0.7}
+ assert generator.model == "anthropic.claude-3-5-sonnet-20240620-v1:0"
assert generator.streaming_callback == print_streaming_chunk
assert generator.boto3_config == boto3_config
@@ -107,13 +110,10 @@ def test_default_constructor(mock_boto3_session, set_env_variables):
"""
layer = AmazonBedrockChatGenerator(
- model="anthropic.claude-v2",
+ model="anthropic.claude-3-5-sonnet-20240620-v1:0",
)
- assert layer.model == "anthropic.claude-v2"
- assert layer.truncate is True
- assert layer.model_adapter.prompt_handler is not None
- assert layer.model_adapter.prompt_handler.model_max_length == 100000
+ assert layer.model == "anthropic.claude-3-5-sonnet-20240620-v1:0"
# assert mocked boto3 client called exactly once
mock_boto3_session.assert_called_once()
@@ -134,18 +134,10 @@ def test_constructor_with_generation_kwargs(mock_boto3_session):
"""
generation_kwargs = {"temperature": 0.7}
- layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", generation_kwargs=generation_kwargs)
- assert "temperature" in layer.model_adapter.generation_kwargs
- assert layer.model_adapter.generation_kwargs["temperature"] == 0.7
- assert layer.model_adapter.truncate is True
-
-
-def test_constructor_with_truncate(mock_boto3_session):
- """
- Test that truncate param is correctly set in the model constructor
- """
- layer = AmazonBedrockChatGenerator(model="anthropic.claude-v2", truncate=False)
- assert layer.model_adapter.truncate is False
+ layer = AmazonBedrockChatGenerator(
+ model="anthropic.claude-3-5-sonnet-20240620-v1:0", generation_kwargs=generation_kwargs
+ )
+ assert layer.generation_kwargs == generation_kwargs
def test_constructor_with_empty_model():
@@ -156,208 +148,15 @@ def test_constructor_with_empty_model():
AmazonBedrockChatGenerator(model="")
-def test_short_prompt_is_not_truncated(mock_boto3_session):
- """
- Test that a short prompt is not truncated
- """
- # Define a short mock prompt and its tokenized version
- mock_prompt_text = "I am a tokenized prompt"
- mock_prompt_tokens = mock_prompt_text.split()
-
- # Mock the tokenizer so it returns our predefined tokens
- mock_tokenizer = MagicMock()
- mock_tokenizer.tokenize.return_value = mock_prompt_tokens
-
- # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
- # Since our mock prompt is 5 tokens long, it doesn't exceed the
- # total limit (5 prompt tokens + 3 generated tokens < 10 tokens)
- max_length_generated_text = 3
- total_model_max_length = 10
-
- with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
- layer = AmazonBedrockChatGenerator(
- "anthropic.claude-v2",
- generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
- )
- prompt_after_resize = layer.model_adapter._ensure_token_limit(mock_prompt_text)
-
- # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it
- assert prompt_after_resize == mock_prompt_text
-
-
-def test_long_prompt_is_truncated(mock_boto3_session):
- """
- Test that a long prompt is truncated
- """
- # Define a long mock prompt and its tokenized version
- long_prompt_text = "I am a tokenized prompt of length eight"
- long_prompt_tokens = long_prompt_text.split()
-
- # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit
- truncated_prompt_text = "I am a tokenized prompt of length"
-
- # Mock the tokenizer to return our predefined tokens
- # convert tokens to our predefined truncated text
- mock_tokenizer = MagicMock()
- mock_tokenizer.tokenize.return_value = long_prompt_tokens
- mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text
-
- # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
- # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
- max_length_generated_text = 3
- total_model_max_length = 10
-
- with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
- layer = AmazonBedrockChatGenerator(
- "anthropic.claude-v2",
- generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
- )
- prompt_after_resize = layer.model_adapter._ensure_token_limit(long_prompt_text)
+class TestAmazonBedrockChatGeneratorInference:
- # The prompt exceeds the limit, _ensure_token_limit truncates it
- assert prompt_after_resize == truncated_prompt_text
-
-
-def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
- """
- Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
- """
- messages = [ChatMessage.from_user("What is the biggest city in United States?")]
-
- # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
- max_length_generated_text = 3
- total_model_max_length = 10
-
- with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
- generator = AmazonBedrockChatGenerator(
- model="anthropic.claude-v2",
- truncate=False,
- generation_kwargs={"model_max_length": total_model_max_length, "max_tokens": max_length_generated_text},
- )
-
- # Mock the _ensure_token_limit method to track if it is called
- with patch.object(
- generator.model_adapter, "_ensure_token_limit", wraps=generator.model_adapter._ensure_token_limit
- ) as mock_ensure_token_limit:
- # Mock the model adapter to avoid actual invocation
- generator.model_adapter.prepare_body = MagicMock(return_value={})
- generator.client = MagicMock()
- generator.client.invoke_model = MagicMock(
- return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
- )
-
- generator.model_adapter.get_responses = MagicMock(
- return_value=[
- ChatMessage.from_assistant(
- content="Some text",
- meta={
- "model": "claude-3-sonnet-20240229",
- "index": 0,
- "finish_reason": "end_turn",
- "usage": {"prompt_tokens": 16, "completion_tokens": 55},
- },
- )
- ]
- )
- # Invoke the generator
- generator.run(messages=messages)
-
- # Ensure _ensure_token_limit was not called
- mock_ensure_token_limit.assert_not_called()
-
- # Check the prompt passed to prepare_body
- generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[], stream=False)
-
-
-@pytest.mark.parametrize(
- "model, expected_model_adapter",
- [
- ("anthropic.claude-v1", AnthropicClaudeChatAdapter),
- ("anthropic.claude-v2", AnthropicClaudeChatAdapter),
- ("eu.anthropic.claude-v1", AnthropicClaudeChatAdapter), # cross-region inference
- ("us.anthropic.claude-v2", AnthropicClaudeChatAdapter), # cross-region inference
- ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter),
- ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial
- ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter),
- ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter),
- ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial
- ("us.meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference
- ("eu.meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference
- ("de.meta.llama2-130b-v5", MetaLlama2ChatAdapter), # cross-region inference
- ("unknown_model", None),
- ],
-)
-def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[BedrockModelChatAdapter]]):
- """
- Test that the correct model adapter is returned for a given model
- """
- model_adapter = AmazonBedrockChatGenerator.get_model_adapter(model=model)
- assert model_adapter == expected_model_adapter
-
-
-class TestAnthropicClaudeAdapter:
- def test_prepare_body_with_default_params(self) -> None:
- layer = AnthropicClaudeChatAdapter(truncate=True, generation_kwargs={})
- prompt = "Hello, how are you?"
- expected_body = {
- "anthropic_version": "bedrock-2023-05-31",
- "max_tokens": 512,
- "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}],
- }
-
- body = layer.prepare_body([ChatMessage.from_user(prompt)])
-
- assert body == expected_body
-
- def test_prepare_body_with_custom_inference_params(self) -> None:
- layer = AnthropicClaudeChatAdapter(
- truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}
- )
- prompt = "Hello, how are you?"
- expected_body = {
- "anthropic_version": "bedrock-2023-05-31",
- "max_tokens": 512,
- "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}],
- "stop_sequences": ["CUSTOM_STOP"],
- "temperature": 0.7,
- "top_k": 5,
- "top_p": 0.8,
- }
-
- body = layer.prepare_body(
- [ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69, stop_sequences=["CUSTOM_STOP"]
- )
-
- assert body == expected_body
-
- @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS)
+ @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
@pytest.mark.integration
- def test_tools_use(self, model_name):
- """
- Test function calling with AWS Bedrock Anthropic adapter
- """
- # See https://docs.anthropic.com/en/docs/tool-use for more information
- tools = [
- {
- "name": "top_song",
- "description": "Get the most popular song played on a radio station.",
- "input_schema": {
- "type": "object",
- "properties": {
- "sign": {
- "type": "string",
- "description": "The call sign for the radio station for which you want the most popular"
- " song. Example calls signs are WZPZ and WKRP.",
- }
- },
- "required": ["sign"],
- },
- }
- ]
- messages = []
- messages.append(ChatMessage.from_user("What is the most popular song on WZPZ?"))
+ def test_default_inference_params(self, model_name, chat_messages):
client = AmazonBedrockChatGenerator(model=model_name)
- response = client.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": {"type": "any"}})
+ response = client.run(chat_messages)
+
+ assert "replies" in response, "Response does not contain 'replies' key"
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"
@@ -366,102 +165,32 @@ def test_tools_use(self, model_name):
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
- assert "top_song" in first_reply.content.lower(), "First reply does not contain top_song"
+ assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"
- fc_response = json.loads(first_reply.content)
- assert "name" in fc_response, "First reply does not contain name of the tool"
- assert "input" in fc_response, "First reply does not contain input of the tool"
-
-
-class TestMistralAdapter:
- def test_prepare_body_with_default_params(self) -> None:
- layer = MistralChatAdapter(truncate=True, generation_kwargs={})
- prompt = "Hello, how are you?"
- expected_body = {
- "max_tokens": 512,
- "prompt": "[INST] Hello, how are you? [/INST]",
- }
- body = layer.prepare_body([ChatMessage.from_user(prompt)])
+ if first_reply.meta and "usage" in first_reply.meta:
+ assert "prompt_tokens" in first_reply.meta["usage"]
+ assert "completion_tokens" in first_reply.meta["usage"]
- assert body == expected_body
+ @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
+ @pytest.mark.integration
+ def test_default_inference_with_streaming(self, model_name, chat_messages):
+ streaming_callback_called = False
+ paris_found_in_response = False
- def test_prepare_body_with_custom_inference_params(self) -> None:
- layer = MistralChatAdapter(truncate=True, generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
- prompt = "Hello, how are you?"
- expected_body = {
- "prompt": "[INST] Hello, how are you? [/INST]",
- "max_tokens": 512,
- "temperature": 0.7,
- "top_p": 0.8,
- }
+ def streaming_callback(chunk: StreamingChunk):
+ nonlocal streaming_callback_called, paris_found_in_response
+ streaming_callback_called = True
+ assert isinstance(chunk, StreamingChunk)
+ assert chunk.content is not None
+ if not paris_found_in_response:
+ paris_found_in_response = "paris" in chunk.content.lower()
- body = layer.prepare_body([ChatMessage.from_user(prompt)], top_p=0.8, top_k=5, max_tokens_to_sample=69)
-
- assert body == expected_body
-
- def test_mistral_chat_template_correct_order(self):
- layer = MistralChatAdapter(truncate=True, generation_kwargs={})
- layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_assistant("B"), ChatMessage.from_user("C")])
- layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_user("B"), ChatMessage.from_assistant("C")])
-
- def test_mistral_chat_template_incorrect_order(self):
- layer = MistralChatAdapter(truncate=True, generation_kwargs={})
- try:
- layer.prepare_body([ChatMessage.from_assistant("B"), ChatMessage.from_assistant("C")])
- msg = "Expected TemplateError"
- raise AssertionError(msg)
- except Exception as e:
- assert "Conversation roles must alternate user/assistant/" in str(e)
-
- try:
- layer.prepare_body([ChatMessage.from_user("A"), ChatMessage.from_user("B")])
- msg = "Expected TemplateError"
- raise AssertionError(msg)
- except Exception as e:
- assert "Conversation roles must alternate user/assistant/" in str(e)
-
- try:
- layer.prepare_body([ChatMessage.from_system("A"), ChatMessage.from_system("B")])
- msg = "Expected TemplateError"
- raise AssertionError(msg)
- except Exception as e:
- assert "Conversation roles must alternate user/assistant/" in str(e)
-
- def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None:
- monkeypatch.delenv("HF_TOKEN", raising=False)
- with (
- patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
- patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
- caplog.at_level(logging.WARNING),
- ):
- MistralChatAdapter(truncate=True, generation_kwargs={})
- mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
- assert "no HF_TOKEN was found" in caplog.text
-
- def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None:
- monkeypatch.setenv("HF_TOKEN", "test")
- with (
- patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
- patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
- ):
- MistralChatAdapter(truncate=True, generation_kwargs={})
- mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")
-
- @pytest.mark.skipif(
- not os.environ.get("HF_API_TOKEN", None),
- reason=(
- "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
- "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
- ),
- )
- @pytest.mark.parametrize("model_name", MISTRAL_MODELS)
- @pytest.mark.integration
- def test_default_inference_params(self, model_name, chat_messages):
- client = AmazonBedrockChatGenerator(model=model_name)
+ client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback)
response = client.run(chat_messages)
- assert "replies" in response, "Response does not contain 'replies' key"
+ assert streaming_callback_called, "Streaming callback was not called"
+ assert paris_found_in_response, "The streaming callback response did not contain 'paris'"
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"
@@ -473,77 +202,44 @@ def test_default_inference_params(self, model_name, chat_messages):
assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"
-
-@pytest.fixture
-def chat_messages():
- messages = [
- ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."),
- ChatMessage.from_user("What's the capital of France?"),
- ]
- return messages
-
-
-class TestMetaLlama2ChatAdapter:
- @pytest.mark.integration
- def test_prepare_body_with_default_params(self) -> None:
- # leave this test as integration because we really need only tokenizer from HF
- # that way we can ensure prompt chat message formatting
- layer = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={})
- prompt = "Hello, how are you?"
- expected_body = {"prompt": "[INST] Hello, how are you? [/INST]", "max_gen_len": 512}
-
- body = layer.prepare_body([ChatMessage.from_user(prompt)])
-
- assert body == expected_body
-
+ @pytest.mark.parametrize("model_name", MODELS_TO_TEST_WITH_TOOLS)
@pytest.mark.integration
- def test_prepare_body_with_custom_inference_params(self) -> None:
- # leave this test as integration because we really need only tokenizer from HF
- # that way we can ensure prompt chat message formatting
- layer = MetaLlama2ChatAdapter(
- truncate=True,
- generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]},
- )
- prompt = "Hello, how are you?"
-
- # expected body is different because stop_sequences and top_k are not supported by MetaLlama2
- expected_body = {
- "prompt": "[INST] Hello, how are you? [/INST]",
- "max_gen_len": 69,
- "temperature": 0.7,
- "top_p": 0.8,
+ def test_tools_use(self, model_name):
+ """
+        Test function calling with the Amazon Bedrock Converse API toolConfig
+ """
+ # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html
+ tool_config = {
+ "tools": [
+ {
+ "toolSpec": {
+ "name": "top_song",
+ "description": "Get the most popular song played on a radio station.",
+ "inputSchema": {
+ "json": {
+ "type": "object",
+ "properties": {
+ "sign": {
+ "type": "string",
+ "description": "The call sign for the radio station "
+ "for which you want the most popular song. "
+                                        "Example call signs are WZPZ and WKRP.",
+ }
+ },
+ "required": ["sign"],
+ }
+ },
+ }
+ }
+ ],
+ # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
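+            # "auto" lets the model decide whether to call the tool or answer directly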
+ "toolChoice": {"auto": {}},
}
- body = layer.prepare_body(
- [ChatMessage.from_user(prompt)],
- temperature=0.7,
- top_p=0.8,
- top_k=5,
- max_gen_len=69,
- stop_sequences=["CUSTOM_STOP"],
- )
-
- assert body == expected_body
-
- @pytest.mark.integration
- def test_get_responses(self) -> None:
- adapter = MetaLlama2ChatAdapter(truncate=True, generation_kwargs={})
- response_body = {"generation": "This is a single response."}
- expected_response = "This is a single response."
- response_message = adapter.get_responses(response_body)
- # assert that the type of each item in the list is a ChatMessage
- for message in response_message:
- assert isinstance(message, ChatMessage)
-
- assert response_message == [ChatMessage.from_assistant(expected_response)]
-
- @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
- @pytest.mark.integration
- def test_default_inference_params(self, model_name, chat_messages):
+        messages = [ChatMessage.from_user("What is the most popular song on WZPZ?")]
client = AmazonBedrockChatGenerator(model=model_name)
- response = client.run(chat_messages)
-
- assert "replies" in response, "Response does not contain 'replies' key"
+ response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config})
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"
@@ -552,32 +248,70 @@ def test_default_inference_params(self, model_name, chat_messages):
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
- assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"
- if first_reply.meta and "usage" in first_reply.meta:
- assert "prompt_tokens" in first_reply.meta["usage"]
- assert "completion_tokens" in first_reply.meta["usage"]
-
- @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
+        # Some models return a "thinking" message first and the tool call as the second reply
+ if len(replies) > 1:
+ second_reply = replies[1]
+ assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance"
+ assert second_reply.content, "Second reply has no content"
+ assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant"
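+            # The tool call is expected as a JSON-serialized Converse toolUse block in the reply content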
+ tool_call = json.loads(second_reply.content)
+ assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
+ assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
+ assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
+ assert (
+ tool_call["input"]["sign"] == "WZPZ"
+ ), f"Tool call {tool_call} does not contain the correct 'input' value"
+ else:
+            # Case where the model returns the tool call as the first reply;
+            # double-check that the tool call is correct
+ tool_call = json.loads(first_reply.content)
+ assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
+ assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
+ assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
+ assert (
+ tool_call["input"]["sign"] == "WZPZ"
+ ), f"Tool call {tool_call} does not contain the correct 'input' value"
+
+ @pytest.mark.parametrize("model_name", STREAMING_TOOL_MODELS)
@pytest.mark.integration
- def test_default_inference_with_streaming(self, model_name, chat_messages):
- streaming_callback_called = False
- paris_found_in_response = False
-
- def streaming_callback(chunk: StreamingChunk):
- nonlocal streaming_callback_called, paris_found_in_response
- streaming_callback_called = True
- assert isinstance(chunk, StreamingChunk)
- assert chunk.content is not None
- if not paris_found_in_response:
- paris_found_in_response = "paris" in chunk.content.lower()
-
- client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback)
- response = client.run(chat_messages)
+ def test_tools_use_with_streaming(self, model_name):
+ """
+        Test function calling with the Amazon Bedrock Converse API toolConfig while streaming the response
+ """
+ # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolConfiguration.html
+ tool_config = {
+ "tools": [
+ {
+ "toolSpec": {
+ "name": "top_song",
+ "description": "Get the most popular song played on a radio station.",
+ "inputSchema": {
+ "json": {
+ "type": "object",
+ "properties": {
+ "sign": {
+ "type": "string",
+ "description": "The call sign for the radio station "
+ "for which you want the most popular song. Example "
+                                        "call signs are WZPZ and WKRP.",
+ }
+ },
+ "required": ["sign"],
+ }
+ },
+ }
+ }
+ ],
+ # See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+ "toolChoice": {"auto": {}},
+ }
- assert streaming_callback_called, "Streaming callback was not called"
- assert paris_found_in_response, "The streaming callback response did not contain 'paris'"
+        messages = [ChatMessage.from_user("What is the most popular song on WZPZ?")]
+ client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=print_streaming_chunk)
+ response = client.run(messages=messages, generation_kwargs={"toolConfig": tool_config})
replies = response["replies"]
assert isinstance(replies, list), "Replies is not a list"
assert len(replies) > 0, "No replies received"
@@ -586,5 +320,139 @@ def streaming_callback(chunk: StreamingChunk):
assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
assert first_reply.content, "First reply has no content"
assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
- assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
assert first_reply.meta, "First reply has no metadata"
+
+        # Some models return a "thinking" message first and the tool call as the second reply
+ if len(replies) > 1:
+ second_reply = replies[1]
+ assert isinstance(second_reply, ChatMessage), "Second reply is not a ChatMessage instance"
+ assert second_reply.content, "Second reply has no content"
+ assert ChatMessage.is_from(second_reply, ChatRole.ASSISTANT), "Second reply is not from the assistant"
+ tool_call = json.loads(second_reply.content)
+ assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
+ assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
+ assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
+ assert (
+ tool_call["input"]["sign"] == "WZPZ"
+ ), f"Tool call {tool_call} does not contain the correct 'input' value"
+ else:
+            # Case where the model returns the tool call as the first reply;
+            # double-check that the tool call is correct
+ tool_call = json.loads(first_reply.content)
+ assert "toolUseId" in tool_call, "Tool call does not contain 'toolUseId' key"
+ assert tool_call["name"] == "top_song", f"Tool call {tool_call} does not contain the correct 'name' value"
+ assert "input" in tool_call, f"Tool call {tool_call} does not contain 'input' key"
+ assert (
+ tool_call["input"]["sign"] == "WZPZ"
+ ), f"Tool call {tool_call} does not contain the correct 'input' value"
+
+ def test_extract_replies_from_response(self, mock_boto3_session):
+ """
+ Test that extract_replies_from_response correctly processes both text and tool use responses
+ """
+ generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
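+        # No Bedrock call is made here; extract_replies_from_response is exercised
+        # directly on canned Converse API response dicts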
+
+ # Test case 1: Simple text response
+ text_response = {
+ "output": {"message": {"role": "assistant", "content": [{"text": "This is a test response"}]}},
+ "stopReason": "complete",
+ "usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
+ }
+
+ replies = generator.extract_replies_from_response(text_response)
+ assert len(replies) == 1
+ assert replies[0].content == "This is a test response"
+ assert replies[0].role == ChatRole.ASSISTANT
+ assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0"
+ assert replies[0].meta["finish_reason"] == "complete"
+ assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
+
+ # Test case 2: Tool use response
+ tool_response = {
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [{"toolUse": {"toolUseId": "123", "name": "test_tool", "input": {"key": "value"}}}],
+ }
+ },
+ "stopReason": "tool_call",
+ "usage": {"inputTokens": 15, "outputTokens": 25, "totalTokens": 40},
+ }
+
+ replies = generator.extract_replies_from_response(tool_response)
+ assert len(replies) == 1
+ tool_content = json.loads(replies[0].content)
+ assert tool_content["toolUseId"] == "123"
+ assert tool_content["name"] == "test_tool"
+ assert tool_content["input"] == {"key": "value"}
+ assert replies[0].meta["finish_reason"] == "tool_call"
+ assert replies[0].meta["usage"] == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}
+
+ # Test case 3: Mixed content response
+ mixed_response = {
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [
+ {"text": "Let me help you with that. I'll use the search tool to find the answer."},
+ {"toolUse": {"toolUseId": "456", "name": "search_tool", "input": {"query": "test"}}},
+ ],
+ }
+ },
+ "stopReason": "complete",
+ "usage": {"inputTokens": 25, "outputTokens": 35, "totalTokens": 60},
+ }
+
+ replies = generator.extract_replies_from_response(mixed_response)
+ assert len(replies) == 2
+ assert replies[0].content == "Let me help you with that. I'll use the search tool to find the answer."
+ tool_content = json.loads(replies[1].content)
+ assert tool_content["toolUseId"] == "456"
+ assert tool_content["name"] == "search_tool"
+ assert tool_content["input"] == {"query": "test"}
+
+ def test_process_streaming_response(self, mock_boto3_session):
+ """
+ Test that process_streaming_response correctly handles streaming events and accumulates responses
+ """
+ generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
+
+ streaming_chunks = []
+
+ def test_callback(chunk: StreamingChunk):
+ streaming_chunks.append(chunk)
+
+ # Simulate a stream of events for both text and tool use
+ events = [
+ {"contentBlockStart": {"start": {"text": ""}}},
+ {"contentBlockDelta": {"delta": {"text": "Let me "}}},
+ {"contentBlockDelta": {"delta": {"text": "help you."}}},
+ {"contentBlockStop": {}},
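+            # Tool input arrives as partial JSON fragments; the generator is expected to assemble and parse them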
+ {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "search_tool"}}}},
+ {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"query":'}}}},
+ {"contentBlockDelta": {"delta": {"toolUse": {"input": '"test"}'}}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "complete"}},
+ {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30}}},
+ ]
+
+ replies = generator.process_streaming_response(events, test_callback)
+
+ # Verify streaming chunks were received for text content
+ assert len(streaming_chunks) == 2
+ assert streaming_chunks[0].content == "Let me "
+ assert streaming_chunks[1].content == "help you."
+
+ # Verify final replies
+ assert len(replies) == 2
+ # Check text reply
+ assert replies[0].content == "Let me help you."
+ assert replies[0].meta["model"] == "anthropic.claude-3-5-sonnet-20240620-v1:0"
+ assert replies[0].meta["finish_reason"] == "complete"
+ assert replies[0].meta["usage"] == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
+
+ # Check tool use reply
+ tool_content = json.loads(replies[1].content)
+ assert tool_content["toolUseId"] == "123"
+ assert tool_content["name"] == "search_tool"
+ assert tool_content["input"] == {"query": "test"}