From 9d1fb5e2fb51fc13f2da0cbfb155ad824518652c Mon Sep 17 00:00:00 2001
From: FloRul
Date: Tue, 13 Aug 2024 13:10:49 -0400
Subject: [PATCH 01/35] copy from bedrock chat generator folder

---
 .../amazon_bedrock/converse/__init__.py      |   3 +
 .../amazon_bedrock/converse/adapters.py      | 565 ++++++++++++++++++
 .../converse/converse_generator.py           | 270 +++++++++
 3 files changed, 838 insertions(+)
 create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/__init__.py
 create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/adapters.py
 create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/__init__.py
new file mode 100644
index 000000000..e873bc332
--- /dev/null
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/adapters.py
new file mode 100644
index 000000000..cace89fe7
--- /dev/null
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/adapters.py
@@ -0,0 +1,565 @@
+import json
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Callable, ClassVar, Dict, List, Optional
+
+from botocore.eventstream import EventStream
+from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler
+
+logger = logging.getLogger(__name__)
+
+
+class BedrockModelChatAdapter(ABC):
+    """
+    Base class for Amazon Bedrock chat model adapters.
+
+    Each subclass addresses the specifics of the particular chat LLM it adapts,
+    preparing the requests for, and extracting the responses from, chat LLMs hosted on Amazon Bedrock.
+    """
+
+    def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None:
+        """
+        Initializes the chat adapter with the truncate parameter and generation kwargs.
+        """
+        self.generation_kwargs = generation_kwargs
+        self.truncate = truncate
+
+    @abstractmethod
+    def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]:
+        """
+        Prepares the body for the Amazon Bedrock request.
+        Subclasses should override this method to package the chat messages into the request.
+
+        :param messages: The chat messages to package into the request.
+        :param inference_kwargs: Additional inference kwargs to use.
+        :returns: The prepared body.
+        """
+
+    def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
+        """
+        Extracts the responses from the Amazon Bedrock response.
+
+        :param response_body: The response body.
+        :returns: The extracted responses.
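+
+        For illustration, a Claude-style response body such as
+        ``{"type": "message", "content": [{"type": "text", "text": "Hello!"}]}`` would be
+        extracted as a single assistant `ChatMessage` with the text "Hello!" (a sketch;
+        the exact body shape is model specific and is handled by each adapter subclass).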
+ """ + return self._extract_messages_from_response(response_body) + + def get_stream_responses( + self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: + streaming_chunks: List[StreamingChunk] = [] + last_decoded_chunk: Dict[str, Any] = {} + for event in stream: + chunk = event.get("chunk") + if chunk: + last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) + streaming_chunk = self._build_streaming_chunk(last_decoded_chunk) + streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk + streaming_chunks.append(streaming_chunk) + responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()] + return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] + + @staticmethod + def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None: + """ + Updates target_dict with values from updates_dict. Merges lists instead of overriding them. + + :param target_dict: The dictionary to update. + :param updates_dict: The dictionary with updates. + :param allowed_params: The list of allowed params to use. + """ + for key, value in updates_dict.items(): + if key not in allowed_params: + logger.warning(f"Parameter '{key}' is not allowed and will be ignored.") + continue + if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): + # Merge lists and remove duplicates + target_dict[key] = sorted(set(target_dict[key] + value)) + else: + # Override the value in target_dict + target_dict[key] = value + + def _get_params( + self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str] + ) -> Dict[str, Any]: + """ + Merges params from inference_kwargs with the default params and self.generation_kwargs. + Uses a helper function to merge lists or override values as necessary. + + :param inference_kwargs: The inference kwargs to merge. + :param default_params: The default params to start with. + :param allowed_params: The list of allowed params to use. + :returns: The merged params. + """ + # Start with a copy of default_params + kwargs = default_params.copy() + + # Update the default params with self.generation_kwargs and finally inference_kwargs + self._update_params(kwargs, self.generation_kwargs, allowed_params) + self._update_params(kwargs, inference_kwargs, allowed_params) + + return kwargs + + def _ensure_token_limit(self, prompt: str) -> str: + """ + Ensures that the prompt is within the token limit for the model. + :param prompt: The prompt to check. + :returns: The resized prompt. + """ + resize_info = self.check_prompt(prompt) + if resize_info["prompt_length"] != resize_info["new_prompt_length"]: + logger.warning( + "The prompt was truncated from %s tokens to %s tokens so that the prompt length and " + "the answer length (%s tokens) fit within the model's max token limit (%s tokens). " + "Shorten the prompt or it will be cut off.", + resize_info["prompt_length"], + max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore + resize_info["max_length"], + resize_info["model_max_length"], + ) + return str(resize_info["resized_prompt"]) + + @abstractmethod + def check_prompt(self, prompt: str) -> Dict[str, Any]: + """ + Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. + + :param prompt: The prompt to check. 
+ :returns: A dictionary containing the resized prompt and additional information. + """ + + @abstractmethod + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + """ + Extracts the messages from the response body. + + :param response_body: The response body. + :returns: The extracted ChatMessage list. + """ + + @abstractmethod + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: + """ + Extracts the content and meta from a streaming chunk. + + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. + """ + + +class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): + """ + Model adapter for the Anthropic Claude chat model. + """ + + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + ALLOWED_PARAMS: ClassVar[List[str]] = [ + "anthropic_version", + "max_tokens", + "stop_sequences", + "temperature", + "top_p", + "top_k", + "system", + ] + + def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): + """ + Initializes the Anthropic Claude chat adapter. + + :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. + :param generation_kwargs: The generation kwargs. + """ + super().__init__(truncate, generation_kwargs) + + # We pop the model_max_length as it is not sent to the model + # but used to truncate the prompt if needed + # Anthropic Claude has a limit of at least 100000 tokens + # https://docs.anthropic.com/claude/reference/input-and-output-sizes + model_max_length = self.generation_kwargs.pop("model_max_length", 100000) + + # Truncate prompt if prompt tokens > model_max_length-max_length + # (max_length is the length of the generated text) + # TODO use Anthropic tokenizer to get the precise prompt length + # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting + self.prompt_handler = DefaultPromptHandler( + tokenizer="gpt2", + model_max_length=model_max_length, + max_length=self.generation_kwargs.get("max_tokens") or 512, + ) + + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + """ + Prepares the body for the Anthropic Claude request. + + :param messages: The chat messages to package into the request. + :param inference_kwargs: Additional inference kwargs to use. + :returns: The prepared body. + """ + default_params = { + "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", + "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required + } + + # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it + stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) + if stop_sequences: + inference_kwargs["stop_sequences"] = stop_sequences + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) + body = {**self.prepare_chat_messages(messages=messages), **params} + return body + + def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: + """ + Prepares the chat messages for the Anthropic Claude request. + + :param messages: The chat messages to prepare. + :returns: The prepared chat messages as a dictionary. 
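+
+        For illustration, the messages [system("Be brief"), user("Hi")] would be prepared
+        as (a sketch, assuming the default text-only content blocks):
+        ``{"system": "Be brief", "messages": [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}]}``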
+ """ + body: Dict[str, Any] = {} + system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None + body["messages"] = [ + self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT) + ] + if system: + body["system"] = system + # Ensure token limit for each message in the body + if self.truncate: + for message in body["messages"]: + for content in message["content"]: + content["text"] = self._ensure_token_limit(content["text"]) + return body + + def check_prompt(self, prompt: str) -> Dict[str, Any]: + """ + Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. + + :param prompt: The prompt to check. + :returns: A dictionary containing the resized prompt and additional information. + """ + return self.prompt_handler(prompt) + + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + """ + Extracts the messages from the response body. + + :param response_body: The response body. + :return: The extracted ChatMessage list. + """ + messages: List[ChatMessage] = [] + if response_body.get("type") == "message": + for content in response_body["content"]: + if content.get("type") == "text": + meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} + messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) + return messages + + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: + """ + Extracts the content and meta from a streaming chunk. + + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. + """ + if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": + return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) + return StreamingChunk(content="", meta=chunk) + + def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: + """ + Convert a ChatMessage to a dictionary with the content and role fields. + :param m: The ChatMessage to convert. + :return: The dictionary with the content and role fields. + """ + return {"content": [{"type": "text", "text": m.content}], "role": m.role.value} + + +class MistralChatAdapter(BedrockModelChatAdapter): + """ + Model adapter for the Mistral chat model. 
+ """ + + chat_template = """ + {% if messages[0]['role'] == 'system' %} + {% set loop_messages = messages[1:] %} + {% set system_message = messages[0]['content'] %} + {% else %} + {% set loop_messages = messages %} + {% set system_message = false %} + {% endif %} + {{bos_token}} + {% for message in loop_messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + {% if loop.index0 == 0 and system_message != false %} + {% set content = system_message + '\n' + message['content'] %} + {% else %} + {% set content = message['content'] %} + {% endif %} + {% if message['role'] == 'user' %} + {{ '[INST] ' + content.strip() + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ content.strip() + eos_token }} + {% endif %} + {% endfor %} + """ + chat_template = "".join(line.strip() for line in chat_template.splitlines()) + + # the above template was designed to match https://docs.mistral.ai/models/#chat-template + # and to support system messages, otherwise we could use the default mistral chat template + # available on HF infrastructure + + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + ALLOWED_PARAMS: ClassVar[List[str]] = [ + "max_tokens", + "safe_prompt", + "random_seed", + "temperature", + "top_p", + ] + + def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): + """ + Initializes the Mistral chat adapter. + :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. + :param generation_kwargs: The generation kwargs. + """ + super().__init__(truncate, generation_kwargs) + + # We pop the model_max_length as it is not sent to the model + # but used to truncate the prompt if needed + # Mistral has a limit of at least 32000 tokens + model_max_length = self.generation_kwargs.pop("model_max_length", 32000) + + # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer + # a) we should get good estimates for the prompt length + # b) we can use apply_chat_template with the template above to delineate ChatMessages + # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model. + tokenizer: PreTrainedTokenizer + if os.environ.get("HF_TOKEN"): + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + else: + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf") + logger.warning( + "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for " + "estimating the prompt length because no HF_TOKEN was found. Using " + "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var " + "HF_TOKEN containing a Hugging Face token and make sure you have access to the model." + ) + + self.prompt_handler = DefaultPromptHandler( + tokenizer=tokenizer, + model_max_length=model_max_length, + max_length=self.generation_kwargs.get("max_tokens") or 512, + ) + + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + """ + Prepares the body for the Mistral request. + + :param messages: The chat messages to package into the request. + :param inference_kwargs: Additional inference kwargs to use. + :returns: The prepared body. 
+ """ + default_params = { + "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required + } + # replace stop_words from inference_kwargs with stop, as this is Mistral specific parameter + stop_words = inference_kwargs.pop("stop_words", []) + if stop_words: + inference_kwargs["stop"] = stop_words + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) + body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + return body + + def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + """ + Prepares the chat messages for the Mistral request. + + :param messages: The chat messages to prepare. + :returns: The prepared chat messages as a string. + """ + # it would be great to use the default mistral chat template, but it doesn't support system messages + # the class variable defined chat_template is a workaround to support system messages + # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json + # but we'll use our custom chat template + prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( + conversation=[self.to_openai_format(m) for m in messages], tokenize=False, chat_template=self.chat_template + ) + if self.truncate: + prepared_prompt = self._ensure_token_limit(prepared_prompt) + return prepared_prompt + + def to_openai_format(self, m: ChatMessage) -> Dict[str, Any]: + """ + Convert the message to the format expected by OpenAI's Chat API. + See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details. + + :returns: A dictionary with the following key: + - `role` + - `content` + - `name` (optional) + """ + msg = {"role": m.role.value, "content": m.content} + if m.name: + msg["name"] = m.name + return msg + + def check_prompt(self, prompt: str) -> Dict[str, Any]: + """ + Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. + + :param prompt: The prompt to check. + :returns: A dictionary containing the resized prompt and additional information. + """ + return self.prompt_handler(prompt) + + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + """ + Extracts the messages from the response body. + + :param response_body: The response body. + :return: The extracted ChatMessage list. + """ + messages: List[ChatMessage] = [] + responses = response_body.get("outputs", []) + for response in responses: + meta = {k: v for k, v in response.items() if k not in ["text"]} + messages.append(ChatMessage.from_assistant(response["text"], meta=meta)) + return messages + + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: + """ + Extracts the content and meta from a streaming chunk. + + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. + """ + response_chunk = chunk.get("outputs", []) + if response_chunk: + return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk) + return StreamingChunk(content="", meta=chunk) + + +class MetaLlama2ChatAdapter(BedrockModelChatAdapter): + """ + Model adapter for the Meta Llama 2 models. 
+ """ + + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"] + + chat_template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" + "{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + + def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: + """ + Initializes the Meta Llama 2 chat adapter. + :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. + :param generation_kwargs: The generation kwargs. + """ + super().__init__(truncate, generation_kwargs) + # We pop the model_max_length as it is not sent to the model + # but used to truncate the prompt if needed + # Llama 2 has context window size of 4096 tokens + # with some exceptions when the context window has been extended + model_max_length = self.generation_kwargs.pop("model_max_length", 4096) + + # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2 + # a) we should get good estimates for the prompt length (empirically close to llama 2) + # b) we can use apply_chat_template with the template above to delineate ChatMessages + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base") + tokenizer.bos_token = "" + tokenizer.eos_token = "" + tokenizer.unk_token = "" + self.prompt_handler = DefaultPromptHandler( + tokenizer=tokenizer, + model_max_length=model_max_length, + max_length=self.generation_kwargs.get("max_gen_len") or 512, + ) + + def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: + """ + Prepares the body for the Meta Llama 2 request. + + :param messages: The chat messages to package into the request. + :param inference_kwargs: Additional inference kwargs to use. + """ + default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} + + # no support for stop words in Meta Llama 2 + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) + body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + return body + + def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + """ + Prepares the chat messages for the Meta Llama 2 request. + + :param messages: The chat messages to prepare. + :returns: The prepared chat messages as a string ready for the model. 
+ """ + prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( + conversation=messages, tokenize=False, chat_template=self.chat_template + ) + + if self.truncate: + prepared_prompt = self._ensure_token_limit(prepared_prompt) + return prepared_prompt + + def check_prompt(self, prompt: str) -> Dict[str, Any]: + """ + Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. + + :param prompt: The prompt to check. + :returns: A dictionary containing the resized prompt and additional information. + + """ + return self.prompt_handler(prompt) + + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: + """ + Extracts the messages from the response body. + + :param response_body: The response body. + :return: The extracted ChatMessage list. + """ + message_tag = "generation" + metadata = {k: v for (k, v) in response_body.items() if k != message_tag} + return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] + + def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: + """ + Extracts the content and meta from a streaming chunk. + + :param chunk: The streaming chunk as dict. + :returns: A StreamingChunk object. + """ + return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py new file mode 100644 index 000000000..00748dc37 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -0,0 +1,270 @@ +import json +import logging +import re +from typing import Any, Callable, ClassVar, Dict, List, Optional, Type + +from botocore.exceptions import ClientError +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.utils.auth import Secret, deserialize_secrets_inplace +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable + +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, +) +from haystack_integrations.common.amazon_bedrock.utils import get_aws_session + +from .adapters import AnthropicClaudeChatAdapter, BedrockModelConverseAdapter, MetaLlama2ChatAdapter, MistralChatAdapter + +logger = logging.getLogger(__name__) + + +@component +class AmazonBedrockChatGenerator: + """ + Completes chats using LLMs hosted on Amazon Bedrock. + + For example, to use the Anthropic Claude 3 Sonnet model, initialize this component with the + 'anthropic.claude-3-sonnet-20240229-v1:0' model name. 
+ + ### Usage example + + ```python + from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator + from haystack.dataclasses import ChatMessage + from haystack.components.generators.utils import print_streaming_chunk + + messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"), + ChatMessage.from_user("What's Natural Language Processing?")] + + + client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", + streaming_callback=print_streaming_chunk) + client.run(messages, generation_kwargs={"max_tokens": 512}) + + ``` + + AmazonBedrockChatGenerator uses AWS for authentication. You can use the AWS CLI to authenticate through your IAM. + For more information on setting up an IAM identity-based policy, see [Amazon Bedrock documentation] + (https://docs.aws.amazon.com/bedrock/latest/userguide/security_iam_id-based-policy-examples.html). + + If the AWS environment is configured correctly, the AWS credentials are not required as they're loaded + automatically from the environment or the AWS configuration file. + If the AWS environment is not configured, set `aws_access_key_id`, `aws_secret_access_key`, + and `aws_region_name` as environment variables or pass them as + [Secret](https://docs.haystack.deepset.ai/v2.0/docs/secret-management) arguments. Make sure the region you set + supports Amazon Bedrock. + """ + + SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { + r"anthropic.claude.*": AnthropicClaudeChatAdapter, + r"meta.llama2.*": MetaLlama2ChatAdapter, + r"mistral.*": MistralChatAdapter, + } + + def __init__( + self, + model: str, + aws_access_key_id: Optional[Secret] = Secret.from_env_var(["AWS_ACCESS_KEY_ID"], strict=False), # noqa: B008 + aws_secret_access_key: Optional[Secret] = Secret.from_env_var( # noqa: B008 + ["AWS_SECRET_ACCESS_KEY"], strict=False + ), + aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 + aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 + aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 + generation_kwargs: Optional[Dict[str, Any]] = None, + stop_words: Optional[List[str]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + truncate: Optional[bool] = True, + ): + """ + Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the + Amazon Bedrock client. + + Note that the AWS credentials are not required if the AWS environment is configured correctly. These are loaded + automatically from the environment or the AWS configuration file and do not need to be provided explicitly via + the constructor. If the AWS environment is not configured users need to provide the AWS credentials via the + constructor. Aside from model, three required parameters are `aws_access_key_id`, `aws_secret_access_key`, + and `aws_region_name`. + + :param model: The model to use for text generation. The model must be available in Amazon Bedrock and must + be specified in the format outlined in the [Amazon Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html). + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. 
Make sure the region you set supports Amazon Bedrock.
+        :param aws_profile_name: AWS profile name.
+        :param generation_kwargs: Keyword arguments sent to the model. These
+            parameters are specific to a model. You can find them in the model's documentation.
+            For example, you can find the
+            Anthropic Claude generation parameters in [Anthropic documentation](https://docs.anthropic.com/claude/reference/complete_post).
+        :param stop_words: A list of stop words that stop the model from generating more text
+            when encountered. You can provide them using
+            this parameter or using the model's `generation_kwargs` under a model's specific key for stop words.
+            For example, you can provide
+            stop words for Anthropic Claude in the `stop_sequences` key.
+        :param streaming_callback: A callback function called when a new token is received from the stream.
+            By default, the model is not set up for streaming. To enable streaming, set this parameter to a callback
+            function that handles the streaming chunks. The callback function receives a
+            [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and
+            switches the streaming mode on.
+        :param truncate: Whether to truncate the prompt messages or not.
+        """
+        if not model:
+            msg = "'model' cannot be None or empty string"
+            raise ValueError(msg)
+        self.model = model
+        self.aws_access_key_id = aws_access_key_id
+        self.aws_secret_access_key = aws_secret_access_key
+        self.aws_session_token = aws_session_token
+        self.aws_region_name = aws_region_name
+        self.aws_profile_name = aws_profile_name
+        self.truncate = truncate
+
+        # get the model adapter for the given model
+        model_adapter_cls = self.get_model_adapter(model=model)
+        if not model_adapter_cls:
+            msg = f"AmazonBedrockChatGenerator doesn't support the model {model}."
+            raise AmazonBedrockConfigurationError(msg)
+        self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {})
+
+        # create the AWS session and client
+        def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
+            return secret.resolve_value() if secret else None
+
+        try:
+            session = get_aws_session(
+                aws_access_key_id=resolve_secret(aws_access_key_id),
+                aws_secret_access_key=resolve_secret(aws_secret_access_key),
+                aws_session_token=resolve_secret(aws_session_token),
+                aws_region_name=resolve_secret(aws_region_name),
+                aws_profile_name=resolve_secret(aws_profile_name),
+            )
+            self.client = session.client("bedrock-runtime")
+        except Exception as exception:
+            msg = (
+                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
+                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
+            )
+            raise AmazonBedrockConfigurationError(msg) from exception
+
+        self.stop_words = stop_words or []
+        self.streaming_callback = streaming_callback
+
+    @component.output_types(replies=List[ChatMessage])
+    def run(
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Generates a list of `ChatMessage` responses to the given messages using the Amazon Bedrock LLM.
+
+        :param messages: The messages to generate a response to.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
+        :param generation_kwargs: Additional generation keyword arguments passed to the model.
+        :returns: A dictionary with the following keys:
+            - `replies`: The generated list of `ChatMessage` objects.
+        """
+        generation_kwargs = generation_kwargs or {}
+        generation_kwargs = generation_kwargs.copy()
+
+        streaming_callback = streaming_callback or self.streaming_callback
+        generation_kwargs["stream"] = streaming_callback is not None
+
+        # check if the prompt is a list of ChatMessage objects
+        if not (
+            isinstance(messages, list)
+            and len(messages) > 0
+            and all(isinstance(message, ChatMessage) for message in messages)
+        ):
+            msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt."
+            raise ValueError(msg)
+
+        body = self.model_adapter.prepare_body(
+            messages=messages, **{"stop_words": self.stop_words, **generation_kwargs}
+        )
+        try:
+            if streaming_callback:
+                response = self.client.invoke_model_with_response_stream(
+                    body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
+                )
+                response_stream = response["body"]
+                replies = self.model_adapter.get_stream_responses(
+                    stream=response_stream, streaming_callback=streaming_callback
+                )
+            else:
+                response = self.client.invoke_model(
+                    body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json"
+                )
+                response_body = json.loads(response.get("body").read().decode("utf-8"))
+                replies = self.model_adapter.get_responses(response_body=response_body)
+        except ClientError as exception:
+            msg = f"Could not run inference on Amazon Bedrock model {self.model} due to: {exception}"
+            raise AmazonBedrockInferenceError(msg) from exception
+
+        # rename the meta keys to be in line with OpenAI meta output keys
+        for response in replies:
+            if response.meta is not None and "usage" in response.meta:
+                response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens")
+                response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens")
+
+        return {"replies": replies}
+
+    @classmethod
+    def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]:
+        """
+        Returns the model adapter for the given model.
+
+        :param model: The model to get the adapter for.
+        :returns: The model adapter for the given model, or None if the model is not supported.
+        """
+        for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
+            if re.fullmatch(pattern, model):
+                return adapter
+        return None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
+        """
+        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        return default_to_dict(
+            self,
+            aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
+            aws_secret_access_key=self.aws_secret_access_key.to_dict() if self.aws_secret_access_key else None,
+            aws_session_token=self.aws_session_token.to_dict() if self.aws_session_token else None,
+            aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
+            aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
+            model=self.model,
+            stop_words=self.stop_words,
+            generation_kwargs=self.model_adapter.generation_kwargs,
+            streaming_callback=callback_name,
+            truncate=self.truncate,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
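+
+        For illustration, a serialization round-trip looks like
+        ``AmazonBedrockChatGenerator.from_dict(generator.to_dict())`` (a sketch; the
+        streaming callback and AWS secrets are reconstructed by the helpers below).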
+ """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + deserialize_secrets_inplace( + data["init_parameters"], + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + ) + return default_from_dict(cls, data) From 946d9ac99fa655331c5d19ff33579e6aaa7b08ac Mon Sep 17 00:00:00 2001 From: FloRul Date: Tue, 13 Aug 2024 21:12:35 -0400 Subject: [PATCH 02/35] modify supported model patterns for supported tools model patterns --- .../amazon_bedrock/converse/adapters.py | 565 ------------------ .../converse/converse_generator.py | 40 +- 2 files changed, 10 insertions(+), 595 deletions(-) delete mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/adapters.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/adapters.py deleted file mode 100644 index cace89fe7..000000000 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/adapters.py +++ /dev/null @@ -1,565 +0,0 @@ -import json -import logging -import os -from abc import ABC, abstractmethod -from typing import Any, Callable, ClassVar, Dict, List, Optional - -from botocore.eventstream import EventStream -from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk -from transformers import AutoTokenizer, PreTrainedTokenizer - -from haystack_integrations.components.generators.amazon_bedrock.handlers import DefaultPromptHandler - -logger = logging.getLogger(__name__) - - -class BedrockModelConverseAdapter(ABC): - """ - Base class for Amazon Bedrock chat model adapters. - - Each subclass of this class is designed to address the unique specificities of a particular chat LLM it adapts, - focusing on preparing the requests and extracting the responses from the Amazon Bedrock hosted chat LLMs. - """ - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: - """ - Initializes the chat adapter with the truncate parameter and generation kwargs. - """ - self.generation_kwargs = generation_kwargs - self.truncate = truncate - - @abstractmethod - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Amazon Bedrock request. - Subclasses should override this method to package the chat messages into the request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - - def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the responses from the Amazon Bedrock response. - - :param response_body: The response body. - :returns: The extracted responses. 
- """ - return self._extract_messages_from_response(response_body) - - def get_stream_responses( - self, stream: EventStream, streaming_callback: Callable[[StreamingChunk], None] - ) -> List[ChatMessage]: - streaming_chunks: List[StreamingChunk] = [] - last_decoded_chunk: Dict[str, Any] = {} - for event in stream: - chunk = event.get("chunk") - if chunk: - last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - streaming_chunk = self._build_streaming_chunk(last_decoded_chunk) - streaming_callback(streaming_chunk) # callback the stream handler with StreamingChunk - streaming_chunks.append(streaming_chunk) - responses = ["".join(chunk.content for chunk in streaming_chunks).lstrip()] - return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] - - @staticmethod - def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None: - """ - Updates target_dict with values from updates_dict. Merges lists instead of overriding them. - - :param target_dict: The dictionary to update. - :param updates_dict: The dictionary with updates. - :param allowed_params: The list of allowed params to use. - """ - for key, value in updates_dict.items(): - if key not in allowed_params: - logger.warning(f"Parameter '{key}' is not allowed and will be ignored.") - continue - if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): - # Merge lists and remove duplicates - target_dict[key] = sorted(set(target_dict[key] + value)) - else: - # Override the value in target_dict - target_dict[key] = value - - def _get_params( - self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str] - ) -> Dict[str, Any]: - """ - Merges params from inference_kwargs with the default params and self.generation_kwargs. - Uses a helper function to merge lists or override values as necessary. - - :param inference_kwargs: The inference kwargs to merge. - :param default_params: The default params to start with. - :param allowed_params: The list of allowed params to use. - :returns: The merged params. - """ - # Start with a copy of default_params - kwargs = default_params.copy() - - # Update the default params with self.generation_kwargs and finally inference_kwargs - self._update_params(kwargs, self.generation_kwargs, allowed_params) - self._update_params(kwargs, inference_kwargs, allowed_params) - - return kwargs - - def _ensure_token_limit(self, prompt: str) -> str: - """ - Ensures that the prompt is within the token limit for the model. - :param prompt: The prompt to check. - :returns: The resized prompt. - """ - resize_info = self.check_prompt(prompt) - if resize_info["prompt_length"] != resize_info["new_prompt_length"]: - logger.warning( - "The prompt was truncated from %s tokens to %s tokens so that the prompt length and " - "the answer length (%s tokens) fit within the model's max token limit (%s tokens). " - "Shorten the prompt or it will be cut off.", - resize_info["prompt_length"], - max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore - resize_info["max_length"], - resize_info["model_max_length"], - ) - return str(resize_info["resized_prompt"]) - - @abstractmethod - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. 
- :returns: A dictionary containing the resized prompt and additional information. - """ - - @abstractmethod - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :returns: The extracted ChatMessage list. - """ - - @abstractmethod - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - - -class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Anthropic Claude chat model. - """ - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - ALLOWED_PARAMS: ClassVar[List[str]] = [ - "anthropic_version", - "max_tokens", - "stop_sequences", - "temperature", - "top_p", - "top_k", - "system", - ] - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): - """ - Initializes the Anthropic Claude chat adapter. - - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Anthropic Claude has a limit of at least 100000 tokens - # https://docs.anthropic.com/claude/reference/input-and-output-sizes - model_max_length = self.generation_kwargs.pop("model_max_length", 100000) - - # Truncate prompt if prompt tokens > model_max_length-max_length - # (max_length is the length of the generated text) - # TODO use Anthropic tokenizer to get the precise prompt length - # See https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#token-counting - self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Anthropic Claude request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. - """ - default_params = { - "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required - } - - # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) - if stop_sequences: - inference_kwargs["stop_sequences"] = stop_sequences - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {**self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: - """ - Prepares the chat messages for the Anthropic Claude request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a dictionary. 
- """ - body: Dict[str, Any] = {} - system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None - body["messages"] = [ - self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT) - ] - if system: - body["system"] = system - # Ensure token limit for each message in the body - if self.truncate: - for message in body["messages"]: - for content in message["content"]: - content["text"] = self._ensure_token_limit(content["text"]) - return body - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - messages: List[ChatMessage] = [] - if response_body.get("type") == "message": - for content in response_body["content"]: - if content.get("type") == "text": - meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} - messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) - return messages - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": - return StreamingChunk(content=chunk.get("delta", {}).get("text", ""), meta=chunk) - return StreamingChunk(content="", meta=chunk) - - def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: - """ - Convert a ChatMessage to a dictionary with the content and role fields. - :param m: The ChatMessage to convert. - :return: The dictionary with the content and role fields. - """ - return {"content": [{"type": "text", "text": m.content}], "role": m.role.value} - - -class MistralChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Mistral chat model. 
- """ - - chat_template = """ - {% if messages[0]['role'] == 'system' %} - {% set loop_messages = messages[1:] %} - {% set system_message = messages[0]['content'] %} - {% else %} - {% set loop_messages = messages %} - {% set system_message = false %} - {% endif %} - {{bos_token}} - {% for message in loop_messages %} - {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {% endif %} - {% if loop.index0 == 0 and system_message != false %} - {% set content = system_message + '\n' + message['content'] %} - {% else %} - {% set content = message['content'] %} - {% endif %} - {% if message['role'] == 'user' %} - {{ '[INST] ' + content.strip() + ' [/INST]' }} - {% elif message['role'] == 'assistant' %} - {{ content.strip() + eos_token }} - {% endif %} - {% endfor %} - """ - chat_template = "".join(line.strip() for line in chat_template.splitlines()) - - # the above template was designed to match https://docs.mistral.ai/models/#chat-template - # and to support system messages, otherwise we could use the default mistral chat template - # available on HF infrastructure - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - ALLOWED_PARAMS: ClassVar[List[str]] = [ - "max_tokens", - "safe_prompt", - "random_seed", - "temperature", - "top_p", - ] - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]): - """ - Initializes the Mistral chat adapter. - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Mistral has a limit of at least 32000 tokens - model_max_length = self.generation_kwargs.pop("model_max_length", 32000) - - # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer - # a) we should get good estimates for the prompt length - # b) we can use apply_chat_template with the template above to delineate ChatMessages - # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model. - tokenizer: PreTrainedTokenizer - if os.environ.get("HF_TOKEN"): - tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") - else: - tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf") - logger.warning( - "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for " - "estimating the prompt length because no HF_TOKEN was found. Using " - "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var " - "HF_TOKEN containing a Hugging Face token and make sure you have access to the model." - ) - - self.prompt_handler = DefaultPromptHandler( - tokenizer=tokenizer, - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Mistral request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - :returns: The prepared body. 
- """ - default_params = { - "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required - } - # replace stop_words from inference_kwargs with stop, as this is Mistral specific parameter - stop_words = inference_kwargs.pop("stop_words", []) - if stop_words: - inference_kwargs["stop"] = stop_words - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - """ - Prepares the chat messages for the Mistral request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string. - """ - # it would be great to use the default mistral chat template, but it doesn't support system messages - # the class variable defined chat_template is a workaround to support system messages - # default is https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json - # but we'll use our custom chat template - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=[self.to_openai_format(m) for m in messages], tokenize=False, chat_template=self.chat_template - ) - if self.truncate: - prepared_prompt = self._ensure_token_limit(prepared_prompt) - return prepared_prompt - - def to_openai_format(self, m: ChatMessage) -> Dict[str, Any]: - """ - Convert the message to the format expected by OpenAI's Chat API. - See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details. - - :returns: A dictionary with the following key: - - `role` - - `content` - - `name` (optional) - """ - msg = {"role": m.role.value, "content": m.content} - if m.name: - msg["name"] = m.name - return msg - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - messages: List[ChatMessage] = [] - responses = response_body.get("outputs", []) - for response in responses: - meta = {k: v for k, v in response.items() if k not in ["text"]} - messages.append(ChatMessage.from_assistant(response["text"], meta=meta)) - return messages - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - response_chunk = chunk.get("outputs", []) - if response_chunk: - return StreamingChunk(content=response_chunk[0].get("text", ""), meta=chunk) - return StreamingChunk(content="", meta=chunk) - - -class MetaLlama2ChatAdapter(BedrockModelChatAdapter): - """ - Model adapter for the Meta Llama 2 models. 
- """ - - # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html - ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"] - - chat_template = ( - "{% if messages[0]['role'] == 'system' %}" - "{% set loop_messages = messages[1:] %}" - "{% set system_message = messages[0]['content'] %}" - "{% else %}" - "{% set loop_messages = messages %}" - "{% set system_message = false %}" - "{% endif %}" - "{% for message in loop_messages %}" - "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" - "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" - "{% endif %}" - "{% if loop.index0 == 0 and system_message != false %}" - "{% set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %}" - "{% else %}" - "{% set content = message['content'] %}" - "{% endif %}" - "{% if message['role'] == 'user' %}" - "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" - "{% elif message['role'] == 'assistant' %}" - "{{ ' ' + content.strip() + ' ' + eos_token }}" - "{% endif %}" - "{% endfor %}" - ) - - def __init__(self, truncate: Optional[bool], generation_kwargs: Dict[str, Any]) -> None: - """ - Initializes the Meta Llama 2 chat adapter. - :param truncate: Whether to truncate the prompt if it exceeds the model's max token limit. - :param generation_kwargs: The generation kwargs. - """ - super().__init__(truncate, generation_kwargs) - # We pop the model_max_length as it is not sent to the model - # but used to truncate the prompt if needed - # Llama 2 has context window size of 4096 tokens - # with some exceptions when the context window has been extended - model_max_length = self.generation_kwargs.pop("model_max_length", 4096) - - # Use `google/flan-t5-base` as it's also BPE sentencepiece tokenizer just like llama 2 - # a) we should get good estimates for the prompt length (empirically close to llama 2) - # b) we can use apply_chat_template with the template above to delineate ChatMessages - tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base") - tokenizer.bos_token = "" - tokenizer.eos_token = "" - tokenizer.unk_token = "" - self.prompt_handler = DefaultPromptHandler( - tokenizer=tokenizer, - model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_gen_len") or 512, - ) - - def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: - """ - Prepares the body for the Meta Llama 2 request. - - :param messages: The chat messages to package into the request. - :param inference_kwargs: Additional inference kwargs to use. - """ - default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} - - # no support for stop words in Meta Llama 2 - params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} - return body - - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: - """ - Prepares the chat messages for the Meta Llama 2 request. - - :param messages: The chat messages to prepare. - :returns: The prepared chat messages as a string ready for the model. 
- """ - prepared_prompt: str = self.prompt_handler.tokenizer.apply_chat_template( - conversation=messages, tokenize=False, chat_template=self.chat_template - ) - - if self.truncate: - prepared_prompt = self._ensure_token_limit(prepared_prompt) - return prepared_prompt - - def check_prompt(self, prompt: str) -> Dict[str, Any]: - """ - Checks the prompt length and resizes it if necessary. If the prompt is too long, it will be truncated. - - :param prompt: The prompt to check. - :returns: A dictionary containing the resized prompt and additional information. - - """ - return self.prompt_handler(prompt) - - def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: - """ - Extracts the messages from the response body. - - :param response_body: The response body. - :return: The extracted ChatMessage list. - """ - message_tag = "generation" - metadata = {k: v for (k, v) in response_body.items() if k != message_tag} - return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] - - def _build_streaming_chunk(self, chunk: Dict[str, Any]) -> StreamingChunk: - """ - Extracts the content and meta from a streaming chunk. - - :param chunk: The streaming chunk as dict. - :returns: A StreamingChunk object. - """ - return StreamingChunk(content=chunk.get("generation", ""), meta=chunk) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 00748dc37..2d6c51539 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -15,15 +15,15 @@ ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session -from .adapters import AnthropicClaudeChatAdapter, BedrockModelConverseAdapter, MetaLlama2ChatAdapter, MistralChatAdapter logger = logging.getLogger(__name__) @component -class AmazonBedrockChatGenerator: +class AmazonBedrockConverseGenerator: """ - Completes chats using LLMs hosted on Amazon Bedrock. + Completes chats using LLMs hosted on Amazon Bedrock using the converse api. + References: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html For example, to use the Anthropic Claude 3 Sonnet model, initialize this component with the 'anthropic.claude-3-sonnet-20240229-v1:0' model name. @@ -57,11 +57,11 @@ class AmazonBedrockChatGenerator: supports Amazon Bedrock. """ - SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"anthropic.claude.*": AnthropicClaudeChatAdapter, - r"meta.llama2.*": MetaLlama2ChatAdapter, - r"mistral.*": MistralChatAdapter, - } + SUPPORTED_TOOL_MODEL_PATTERNS: ClassVar[List[str]] = [ + r"anthropic.claude-3.*", + r"cohere.command-r.*", + r"mistral.mistral-large.*", + ] def __init__( self, @@ -79,7 +79,7 @@ def __init__( truncate: Optional[bool] = True, ): """ - Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the + Initializes the `AmazonBedrockConverseGenerator` with the provided parameters. The parameters are passed to the Amazon Bedrock client. Note that the AWS credentials are not required if the AWS environment is configured correctly. 
These are loaded @@ -122,13 +122,6 @@ def __init__( self.aws_profile_name = aws_profile_name self.truncate = truncate - # get the model adapter for the given model - model_adapter_cls = self.get_model_adapter(model=model) - if not model_adapter_cls: - msg = f"AmazonBedrockGenerator doesn't support the model {model}." - raise AmazonBedrockConfigurationError(msg) - self.model_adapter = model_adapter_cls(self.truncate, generation_kwargs or {}) - # create the AWS session and client def resolve_secret(secret: Optional[Secret]) -> Optional[str]: return secret.resolve_value() if secret else None @@ -214,19 +207,6 @@ def run( return {"replies": replies} - @classmethod - def get_model_adapter(cls, model: str) -> Optional[Type[BedrockModelChatAdapter]]: - """ - Returns the model adapter for the given model. - - :param model: The model to get the adapter for. - :returns: The model adapter for the given model, or None if the model is not supported. - """ - for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): - if re.fullmatch(pattern, model): - return adapter - return None - def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -250,7 +230,7 @@ def to_dict(self) -> Dict[str, Any]: ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockChatGenerator": + def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": """ Deserializes the component from a dictionary. From c54de1882b00bc75077ce40d44ea89f388520897 Mon Sep 17 00:00:00 2001 From: FloRul Date: Wed, 14 Aug 2024 21:43:59 -0400 Subject: [PATCH 03/35] add unsupported models for chat patterns --- .../converse/converse_generator.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 2d6c51539..2782f1c40 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -57,12 +57,19 @@ class AmazonBedrockConverseGenerator: supports Amazon Bedrock. 
""" + # according to the list provided in the toolConfig arg: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax SUPPORTED_TOOL_MODEL_PATTERNS: ClassVar[List[str]] = [ r"anthropic.claude-3.*", r"cohere.command-r.*", r"mistral.mistral-large.*", ] + UNSUPPORTED_CHAT_MODEL_PATTERNS: ClassVar[List[str]] = [ + r"cohere.command-text.*", + r"cohere.command-light.*", + r"ai21.j2.*", + ] + def __init__( self, model: str, @@ -73,8 +80,11 @@ def __init__( aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 - generation_kwargs: Optional[Dict[str, Any]] = None, - stop_words: Optional[List[str]] = None, + # converse parameters + # inference_config: maxTokens, stopSequences, temperature,topP + inference_config: Optional[Dict[str, Any]] = None, + additionalModelRequestFields: Optional[Dict[str, Any]] = None, + tool_config: Optional[Dict[str, Any]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, truncate: Optional[bool] = True, ): @@ -96,9 +106,7 @@ def __init__( :param aws_region_name: AWS region name. Make sure the region you set supports Amazon Bedrock. :param aws_profile_name: AWS profile name. :param generation_kwargs: Keyword arguments sent to the model. These - parameters are specific to a model. You can find them in the model's documentation. - For example, you can find the - Anthropic Claude generation parameters in [Anthropic documentation](https://docs.anthropic.com/claude/reference/complete_post). + parameters are specific to a model. You can find them in the [converse documentation](). :param stop_words: A list of stop words that stop the model from generating more text when encountered. You can provide them using this parameter or using the model's `generation_kwargs` under a model's specific key for stop words. @@ -151,6 +159,7 @@ def run( messages: List[ChatMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, + toolConfig: Optional[Dict[str, Any]] = None, ): """ Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. @@ -168,6 +177,13 @@ def run( streaming_callback = streaming_callback or self.streaming_callback generation_kwargs["stream"] = streaming_callback is not None + # warn and only keep last message if model does not support chat + if re.match("|".join(self.UNSUPPORTED_CHAT_MODEL_PATTERNS), self.model) and len(messages) > 1: + logging.warning( + f"The model {self.model} does not support chat. Only the last message " "will be taken into account." 
+ ) + messages = messages[-1:] + # check if the prompt is a list of ChatMessage objects if not ( isinstance(messages, list) From dfa27bfc8d710f63391f2f29999d6288d6c9f27f Mon Sep 17 00:00:00 2001 From: FloRul Date: Wed, 14 Aug 2024 23:23:14 -0400 Subject: [PATCH 04/35] add converseMessage class --- .../converse/converse_generator.py | 29 ++++---- .../amazon_bedrock/converse/utils.py | 69 +++++++++++++++++++ 2 files changed, 85 insertions(+), 13 deletions(-) create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 2782f1c40..1d3c43e21 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -14,10 +14,14 @@ AmazonBedrockInferenceError, ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session - +from utils import ContentBlock logger = logging.getLogger(__name__) +class ConverseMessage: + def __init__(self, role: str, content: ContentBlock): + self.role = role + self.content = content @component class AmazonBedrockConverseGenerator: @@ -129,6 +133,9 @@ def __init__( self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name self.truncate = truncate + self.inference_config = inference_config + self.additionalModelRequestFields = additionalModelRequestFields + self.tool_config = tool_config # create the AWS session and client def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -150,16 +157,16 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception - self.stop_words = stop_words or [] self.streaming_callback = streaming_callback @component.output_types(replies=List[ChatMessage]) def run( self, - messages: List[ChatMessage], + messages: List[ConverseMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - generation_kwargs: Optional[Dict[str, Any]] = None, - toolConfig: Optional[Dict[str, Any]] = None, + inference_config: Optional[Dict[str, Any]] = None, + additionalModelRequestFields: Optional[Dict[str, Any]] = None, + tool_config: Optional[Dict[str, Any]] = None, ): """ Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. @@ -171,11 +178,10 @@ def run( :returns: A dictionary with the following keys: - `replies`: The generated List of `ChatMessage` objects. """ - generation_kwargs = generation_kwargs or {} - generation_kwargs = generation_kwargs.copy() + inference_config = inference_config or {} + inference_config = inference_config.copy() streaming_callback = streaming_callback or self.streaming_callback - generation_kwargs["stream"] = streaming_callback is not None # warn and only keep last message if model does not support chat if re.match("|".join(self.UNSUPPORTED_CHAT_MODEL_PATTERNS), self.model) and len(messages) > 1: @@ -193,9 +199,7 @@ def run( msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." 
raise ValueError(msg) - body = self.model_adapter.prepare_body( - messages=messages, **{"stop_words": self.stop_words, **generation_kwargs} - ) + body = {"messages":} try: if streaming_callback: response = self.client.invoke_model_with_response_stream( @@ -239,8 +243,7 @@ def to_dict(self) -> Dict[str, Any]: aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, - stop_words=self.stop_words, - generation_kwargs=self.model_adapter.generation_kwargs, + inference_config=self.inference_config, streaming_callback=callback_name, truncate=self.truncate, ) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py new file mode 100644 index 000000000..e633d3635 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Union, Optional + + +@dataclass +class DocumentBlock: + # Placeholder for DocumentBlock attributes + pass + + +@dataclass +class GuardrailConverseContentBlock: + # Placeholder for GuardrailConverseContentBlock attributes + pass + + +@dataclass +class ImageBlock: + # Placeholder for ImageBlock attributes + pass + + +@dataclass +class ToolResultBlock: + # Placeholder for ToolResultBlock attributes + pass + + +@dataclass +class ToolUseBlock: + # Placeholder for ToolUseBlock attributes + pass + + +@dataclass +class ContentBlock: + content: Union[DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock] + + def __post_init__(self): + if not isinstance( + self.content, (DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock) + ): + raise ValueError( + "Invalid content type. 
Must be one of DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, or ToolUseBlock" + ) + + @property + def document(self) -> Optional[DocumentBlock]: + return self.content if isinstance(self.content, DocumentBlock) else None + + @property + def guard_content(self) -> Optional[GuardrailConverseContentBlock]: + return self.content if isinstance(self.content, GuardrailConverseContentBlock) else None + + @property + def image(self) -> Optional[ImageBlock]: + return self.content if isinstance(self.content, ImageBlock) else None + + @property + def text(self) -> Optional[str]: + return self.content if isinstance(self.content, str) else None + + @property + def tool_result(self) -> Optional[ToolResultBlock]: + return self.content if isinstance(self.content, ToolResultBlock) else None + + @property + def tool_use(self) -> Optional[ToolUseBlock]: + return self.content if isinstance(self.content, ToolUseBlock) else None From bd9160022698fd1b8f72737a9e6d13831e87ff50 Mon Sep 17 00:00:00 2001 From: FloRul Date: Thu, 15 Aug 2024 00:13:18 -0400 Subject: [PATCH 05/35] add utils and call converse with only text --- .../converse/converse_generator.py | 20 +++--- .../amazon_bedrock/converse/utils.py | 65 ++++++++++++++++--- 2 files changed, 65 insertions(+), 20 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 1d3c43e21..3d8a9cf3a 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -14,14 +14,10 @@ AmazonBedrockInferenceError, ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session -from utils import ContentBlock +from utils import ConverseMessage logger = logging.getLogger(__name__) -class ConverseMessage: - def __init__(self, role: str, content: ContentBlock): - self.role = role - self.content = content @component class AmazonBedrockConverseGenerator: @@ -190,16 +186,16 @@ def run( ) messages = messages[-1:] - # check if the prompt is a list of ChatMessage objects + # check if the prompt is a list of ConverseMessage objects if not ( isinstance(messages, list) and len(messages) > 0 - and all(isinstance(message, ChatMessage) for message in messages) + and all(isinstance(message, ConverseMessage) for message in messages) ): - msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." + msg = f"The model {self.model} requires a list of ConverseMessage objects as a prompt." 
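# For context, a minimal sketch of the plain boto3 call this component wraps,
# assuming configured AWS credentials and access to the Claude 3 Haiku model
# used elsewhere in this PR:
#     client = boto3.client("bedrock-runtime")
#     client.converse(
#         modelId="anthropic.claude-3-haiku-20240307-v1:0",
#         messages=[{"role": "user", "content": [{"text": "Hello!"}]}],
#     )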
raise ValueError(msg) - body = {"messages":} + body = {"messages": message.to_dict() for message in messages} try: if streaming_callback: response = self.client.invoke_model_with_response_stream( @@ -210,10 +206,10 @@ def run( stream=response_stream, streaming_callback=streaming_callback ) else: - response = self.client.invoke_model( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" + response = self.client.converse( + modelId=self.model, + messages=[message.to_dict() for message in messages], ) - response_body = json.loads(response.get("body").read().decode("utf-8")) replies = self.model_adapter.get_responses(response_body=response_body) except ClientError as exception: msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index e633d3635..1767542e4 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import Union, Optional +from enum import Enum +from typing import List, Union, Optional @dataclass @@ -32,17 +33,35 @@ class ToolUseBlock: pass +from dataclasses import dataclass, asdict +from typing import Union, Optional + + @dataclass class ContentBlock: - content: Union[DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock] + content: List[Union[DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock]] def __post_init__(self): - if not isinstance( - self.content, (DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock) - ): - raise ValueError( - "Invalid content type. Must be one of DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, or ToolUseBlock" - ) + if not isinstance(self.content, list): + raise ValueError("Content must be a list") + + for item in self.content: + if not isinstance( + item, (DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock) + ): + raise ValueError( + f"Invalid content type: {type(item)}. 
Each item must be one of DocumentBlock, " + "GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, or ToolUseBlock" + ) + + def to_dict(self): + res = [] + for item in self.content: + if isinstance(item, str): + res.append({"text": item}) + else: + raise NotImplementedError + return res @property def document(self) -> Optional[DocumentBlock]: @@ -67,3 +86,33 @@ def tool_result(self) -> Optional[ToolResultBlock]: @property def tool_use(self) -> Optional[ToolUseBlock]: return self.content if isinstance(self.content, ToolUseBlock) else None + + +class ConverseRole(str, Enum): + USER = "user" + ASSISTANT = "assistant" + + +class ConverseMessage: + def __init__(self, role: ConverseRole, content: ContentBlock): + self.role = role + self.content = content + + @staticmethod + def from_user( + content: List[ + Union[DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock] + ] + ) -> "ConverseMessage": + return ConverseMessage( + ConverseRole.USER, + ContentBlock( + content=content, + ), + ) + + def to_dict(self): + return { + "role": self.role.value, + "content": self.content.to_dict(), + } From bcf91cd0352b66801350dc4f2dbe9fbcc8a416dc Mon Sep 17 00:00:00 2001 From: FloRul Date: Thu, 15 Aug 2024 16:00:34 -0400 Subject: [PATCH 06/35] adding tool config class and clean up from chat copy --- .../converse/converse_generator.py | 38 ++------- .../amazon_bedrock/converse/utils.py | 78 ++++++++++++++++++- 2 files changed, 84 insertions(+), 32 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 3d8a9cf3a..7c7b8d36c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -80,13 +80,7 @@ def __init__( aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 - # converse parameters - # inference_config: maxTokens, stopSequences, temperature,topP - inference_config: Optional[Dict[str, Any]] = None, - additionalModelRequestFields: Optional[Dict[str, Any]] = None, - tool_config: Optional[Dict[str, Any]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - truncate: Optional[bool] = True, ): """ Initializes the `AmazonBedrockConverseGenerator` with the provided parameters. 
The parameters are passed to the @@ -128,10 +122,6 @@ def __init__( self.aws_session_token = aws_session_token self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name - self.truncate = truncate - self.inference_config = inference_config - self.additionalModelRequestFields = additionalModelRequestFields - self.tool_config = tool_config # create the AWS session and client def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -155,14 +145,14 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: self.streaming_callback = streaming_callback - @component.output_types(replies=List[ChatMessage]) + @component.output_types(output=ConverseMessage) def run( self, messages: List[ConverseMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, - inference_config: Optional[Dict[str, Any]] = None, - additionalModelRequestFields: Optional[Dict[str, Any]] = None, - tool_config: Optional[Dict[str, Any]] = None, + inference_config: Dict[str, Any] = {}, + tool_config: Dict[str, Any] = {}, + **kwargs, ): """ Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. @@ -174,9 +164,6 @@ def run( :returns: A dictionary with the following keys: - `replies`: The generated List of `ChatMessage` objects. """ - inference_config = inference_config or {} - inference_config = inference_config.copy() - streaming_callback = streaming_callback or self.streaming_callback # warn and only keep last message if model does not support chat @@ -195,20 +182,16 @@ def run( msg = f"The model {self.model} requires a list of ConverseMessage objects as a prompt." raise ValueError(msg) - body = {"messages": message.to_dict() for message in messages} try: if streaming_callback: - response = self.client.invoke_model_with_response_stream( - body=json.dumps(body), modelId=self.model, accept="application/json", contentType="application/json" - ) - response_stream = response["body"] - replies = self.model_adapter.get_stream_responses( - stream=response_stream, streaming_callback=streaming_callback - ) + raise NotImplementedError else: response = self.client.converse( modelId=self.model, + inferenceConfig=inference_config, messages=[message.to_dict() for message in messages], + toolConfig=tool_config, + **kwargs, ) replies = self.model_adapter.get_responses(response_body=response_body) except ClientError as exception: @@ -216,11 +199,6 @@ def run( raise AmazonBedrockInferenceError(msg) from exception # rename the meta key to be inline with OpenAI meta output keys - for response in replies: - if response.meta is not None and "usage" in response.meta: - response.meta["usage"]["prompt_tokens"] = response.meta["usage"].pop("input_tokens") - response.meta["usage"]["completion_tokens"] = response.meta["usage"].pop("output_tokens") - return {"replies": replies} def to_dict(self) -> Dict[str, Any]: diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 1767542e4..13cbf046f 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -2,6 +2,58 @@ from enum import Enum from typing import List, Union, Optional +from dataclasses import dataclass, field +from typing import List, Dict, Union, 
Optional + + +@dataclass +class ToolSpec: + name: str + description: Optional[str] = None + inputSchema: Dict[str, Dict] = field(default_factory=dict) + + +@dataclass +class Tool: + toolSpec: ToolSpec + + +@dataclass +class ToolChoice: + auto: Dict = field(default_factory=dict) + any: Dict = field(default_factory=dict) + tool: Optional[Dict[str, str]] = None + + +@dataclass +class ToolConfig: + tools: List[Tool] + toolChoice: Optional[ToolChoice] = None + + def __post_init__(self): + if self.toolChoice and sum(bool(v) for v in vars(self.toolChoice).values()) != 1: + raise ValueError("Only one of 'auto', 'any', or 'tool' can be set in toolChoice") + + if self.toolChoice and self.toolChoice.tool: + if 'name' not in self.toolChoice.tool: + raise ValueError("'name' is required when 'tool' is specified in toolChoice") + + @classmethod + def from_dict(cls, config: Dict) -> 'ToolConfig': + tools = [Tool(ToolSpec(**tool['toolSpec'])) for tool in config.get('tools', [])] + + tool_choice = None + if 'toolChoice' in config: + tc = config['toolChoice'] + if 'auto' in tc: + tool_choice = ToolChoice(auto=tc['auto']) + elif 'any' in tc: + tool_choice = ToolChoice(any=tc['any']) + elif 'tool' in tc: + tool_choice = ToolChoice(tool={'name': tc['tool']['name']}) + + return cls(tools=tools, toolChoice=tool_choice) + @dataclass class DocumentBlock: @@ -101,8 +153,14 @@ def __init__(self, role: ConverseRole, content: ContentBlock): @staticmethod def from_user( content: List[ - Union[DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock] - ] + Union[ + DocumentBlock, + GuardrailConverseContentBlock, + ImageBlock, + str, + ToolUseBlock, + ], + ], ) -> "ConverseMessage": return ConverseMessage( ConverseRole.USER, @@ -111,6 +169,22 @@ def from_user( ), ) + @staticmethod + def from_assistant( + content: List[ + Union[ + str, + ToolResultBlock, + ], + ], + ) -> "ConverseMessage": + return ConverseMessage( + ConverseRole.ASSISTANT, + ContentBlock( + content=content, + ), + ) + def to_dict(self): return { "role": self.role.value, From f0f5cda5f990cecdb0f2ff17015ba30c6817537b Mon Sep 17 00:00:00 2001 From: FloRul Date: Sun, 18 Aug 2024 13:15:42 -0400 Subject: [PATCH 07/35] bring tools config to kwargs --- .../converse/converse_generator.py | 26 +++- .../amazon_bedrock/converse/utils.py | 135 +++++++++--------- 2 files changed, 92 insertions(+), 69 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 7c7b8d36c..60d4cbbc1 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -8,7 +8,6 @@ from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable - from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, AmazonBedrockInferenceError, @@ -145,13 +144,18 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: self.streaming_callback = streaming_callback - @component.output_types(output=ConverseMessage) + @component.output_types( + 
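# Aside: a sketch of how a caller might consume the output sockets declared
# just below (names as defined in this patch):
#     out = generator.run(messages=[...])
#     out["message"], out["usage"], out["stop_reason"]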
message=ConverseMessage, + usage=Dict[str, Any], + metrics=Dict[str, Any], + guardrail_trace=Dict[str, Any], + stop_reason=str, + ) def run( self, messages: List[ConverseMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, inference_config: Dict[str, Any] = {}, - tool_config: Dict[str, Any] = {}, **kwargs, ): """ @@ -190,10 +194,22 @@ def run( modelId=self.model, inferenceConfig=inference_config, messages=[message.to_dict() for message in messages], - toolConfig=tool_config, **kwargs, ) - replies = self.model_adapter.get_responses(response_body=response_body) + output = response.get("output") + if output is None: + raise KeyError + message = output.get("message") + if message is None: + raise KeyError + + return { + "message": ConverseMessage.from_dict(message), + "usage": response.get("usage"), + "metrics": response.get("metrics"), + "guardrail_trace": response.get("trace"), + "stop_reason": response.get("stopReason"), + } except ClientError as exception: msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" raise AmazonBedrockInferenceError(msg) from exception diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 13cbf046f..23f853bdf 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -1,9 +1,6 @@ -from dataclasses import dataclass +from dataclasses import asdict, dataclass, field +from typing import Any, List, Dict, Union, Optional from enum import Enum -from typing import List, Union, Optional - -from dataclasses import dataclass, field -from typing import List, Dict, Union, Optional @dataclass @@ -55,38 +52,60 @@ def from_dict(cls, config: Dict) -> 'ToolConfig': return cls(tools=tools, toolChoice=tool_choice) +@dataclass +class DocumentSource: + bytes: bytes + + @dataclass class DocumentBlock: - # Placeholder for DocumentBlock attributes - pass + format: str + name: str + source: DocumentSource @dataclass class GuardrailConverseContentBlock: - # Placeholder for GuardrailConverseContentBlock attributes - pass + text: str + qualifiers: List[str] = field(default_factory=list) + + +@dataclass +class ImageSource: + bytes: bytes @dataclass class ImageBlock: - # Placeholder for ImageBlock attributes - pass + format: str + source: ImageSource + + +@dataclass +class ToolResultContentBlock: + json: Optional[Dict] = None + text: Optional[str] = None + image: Optional[ImageBlock] = None + document: Optional[DocumentBlock] = None @dataclass class ToolResultBlock: - # Placeholder for ToolResultBlock attributes - pass + toolUseId: str + content: List[ToolResultContentBlock] + status: Optional[str] = None @dataclass class ToolUseBlock: - # Placeholder for ToolUseBlock attributes - pass + toolUseId: str + name: str + input: Dict -from dataclasses import dataclass, asdict -from typing import Union, Optional +class ConverseRole(str, Enum): + USER = "user" + ASSISTANT = "assistant" @dataclass @@ -111,44 +130,25 @@ def to_dict(self): for item in self.content: if isinstance(item, str): res.append({"text": item}) + elif isinstance(item, DocumentBlock): + res.append({"document": asdict(item)}) + elif isinstance(item, GuardrailConverseContentBlock): + res.append({"guardContent": 
asdict(item)}) + elif isinstance(item, ImageBlock): + res.append({"image": asdict(item)}) + elif isinstance(item, ToolResultBlock): + res.append({"toolResult": asdict(item)}) + elif isinstance(item, ToolUseBlock): + res.append({"toolUse": asdict(item)}) else: - raise NotImplementedError + raise ValueError(f"Unsupported content type: {type(item)}") return res - @property - def document(self) -> Optional[DocumentBlock]: - return self.content if isinstance(self.content, DocumentBlock) else None - - @property - def guard_content(self) -> Optional[GuardrailConverseContentBlock]: - return self.content if isinstance(self.content, GuardrailConverseContentBlock) else None - - @property - def image(self) -> Optional[ImageBlock]: - return self.content if isinstance(self.content, ImageBlock) else None - - @property - def text(self) -> Optional[str]: - return self.content if isinstance(self.content, str) else None - - @property - def tool_result(self) -> Optional[ToolResultBlock]: - return self.content if isinstance(self.content, ToolResultBlock) else None - - @property - def tool_use(self) -> Optional[ToolUseBlock]: - return self.content if isinstance(self.content, ToolUseBlock) else None - - -class ConverseRole(str, Enum): - USER = "user" - ASSISTANT = "assistant" - +@dataclass class ConverseMessage: - def __init__(self, role: ConverseRole, content: ContentBlock): - self.role = role - self.content = content + role: ConverseRole + content: ContentBlock @staticmethod def from_user( @@ -170,20 +170,27 @@ def from_user( ) @staticmethod - def from_assistant( - content: List[ - Union[ - str, - ToolResultBlock, - ], - ], - ) -> "ConverseMessage": - return ConverseMessage( - ConverseRole.ASSISTANT, - ContentBlock( - content=content, - ), - ) + def from_dict(data: Dict[str, Any]) -> "ConverseMessage": + role = ConverseRole(data['role']) + content_blocks = [] + + for item in data['content']: + if 'text' in item: + content_blocks.append(item['text']) + elif 'image' in item: + content_blocks.append(ImageBlock(**item['image'])) + elif 'document' in item: + content_blocks.append(DocumentBlock(**item['document'])) + elif 'toolUse' in item: + content_blocks.append(ToolUseBlock(**item['toolUse'])) + elif 'toolResult' in item: + content_blocks.append(ToolResultBlock(**item['toolResult'])) + elif 'guardContent' in item: + content_blocks.append(GuardrailConverseContentBlock(**item['guardContent'])) + else: + raise ValueError(f"Unknown content type in message: {item}") + + return ConverseMessage(role, ContentBlock(content=content_blocks)) def to_dict(self): return { From 0eddd9d4e57e68f9b857d3ebf88380b0d8080759 Mon Sep 17 00:00:00 2001 From: Florian Rumiel Date: Tue, 20 Aug 2024 09:35:50 -0400 Subject: [PATCH 08/35] first draft for tool calling - add test file --- .../converse/converse_generator.py | 42 ++++++++----- .../amazon_bedrock/converse/test_gen.py | 59 +++++++++++++++++++ .../amazon_bedrock/converse/utils.py | 36 ++++++++++- 3 files changed, 122 insertions(+), 15 deletions(-) create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 60d4cbbc1..b6fd5709e 100644 --- 
a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -1,4 +1,3 @@ -import json import logging import re from typing import Any, Callable, ClassVar, Dict, List, Optional, Type @@ -80,6 +79,9 @@ def __init__( aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + # used for pipeline setup + inference_config: Optional[Dict[str, Any]] = None, + tool_config: Optional[Dict[str, Any]] = None, ): """ Initializes the `AmazonBedrockConverseGenerator` with the provided parameters. The parameters are passed to the @@ -135,6 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_profile_name=resolve_secret(aws_profile_name), ) self.client = session.client("bedrock-runtime") + + self.inference_config = inference_config + self.tool_config = tool_config + except Exception as exception: msg = ( "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " @@ -156,7 +162,7 @@ def run( messages: List[ConverseMessage], streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, inference_config: Dict[str, Any] = {}, - **kwargs, + tool_config: Optional[Dict[str, Any]] = None, ): """ Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. @@ -190,12 +196,19 @@ def run( if streaming_callback: raise NotImplementedError else: - response = self.client.converse( - modelId=self.model, - inferenceConfig=inference_config, - messages=[message.to_dict() for message in messages], - **kwargs, - ) + # toolConfig is optionnal, so we can add it only if tool_config is not None + request_kwargs = { + "modelId": self.model, + "inferenceConfig": inference_config, + "messages": [message.to_dict() for message in messages], + } + + tool_config = tool_config or self.tool_config + if tool_config is not None: + request_kwargs["toolConfig"] = tool_config + + response = self.client.converse(**request_kwargs) + output = response.get("output") if output is None: raise KeyError @@ -214,9 +227,6 @@ def run( msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" raise AmazonBedrockInferenceError(msg) from exception - # rename the meta key to be inline with OpenAI meta output keys - return {"replies": replies} - def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. 
@@ -233,9 +243,7 @@ def to_dict(self) -> Dict[str, Any]: aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None, aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, - inference_config=self.inference_config, streaming_callback=callback_name, - truncate=self.truncate, ) @classmethod @@ -254,6 +262,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) deserialize_secrets_inplace( data["init_parameters"], - ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_profile_name", + ], ) return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py new file mode 100644 index 000000000..f4a48a017 --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py @@ -0,0 +1,59 @@ +import sys +from typing import Dict +from haystack.dataclasses import ChatMessage +from converse_generator import AmazonBedrockConverseGenerator +from utils import ConverseMessage, ToolConfig +from haystack.core.pipeline import Pipeline + + +def get_current_weather(location: str, unit: str = "celsius") -> str: + """Get the current weather in a given location""" + # This is a mock function, replace with actual API call + return f"The weather in {location} is 22 degrees {unit}." + + +def get_current_time(timezone: str) -> str: + """Get the current time in a given timezone""" + # This is a mock function, replace with actual time lookup + return f"The current time in {timezone} is 14:30." 
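# For illustration (an aside, assuming the ToolConfig.from_functions helper added
# to utils.py in this patch): the tool entry it is expected to derive from
# get_current_weather above. Note that the helper types every parameter as a
# plain string and lists defaulted parameters as required too.
EXAMPLE_WEATHER_TOOL = {
    "toolSpec": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "inputSchema": {
            "json": {
                "type": "object",
                "properties": {"location": {"type": "string"}, "unit": {"type": "string"}},
                "required": ["location", "unit"],
            }
        },
    }
}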
+
+
+def main():
+    g = AmazonBedrockConverseGenerator(model="anthropic.claude-3-haiku-20240307-v1:0")
+
+    # Create ToolConfig from functions
+    tool_config = ToolConfig.from_functions([get_current_weather, get_current_time])
+
+    # Convert ToolConfig to dict for use in the run method
+    tool_config_dict = tool_config.to_dict()
+
+    print("Tool Config:")
+    print(tool_config_dict)
+
+    p = Pipeline()
+    p.add_component("generator", g)
+
+    print("\nRunning pipeline with tools:")
+    result = p.run(
+        data={
+            "generator": {
+                "inference_config": {
+                    "temperature": 0.1,
+                    "maxTokens": 256,
+                    "topP": 0.1,
+                    "stopSequences": ["\\n"],
+                },
+                "messages": [
+                    ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"]),
+                ],
+                "tool_config": tool_config_dict,
+            },
+        },
+    )
+
+    print("\nPipeline Result:")
+    print(result)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py
index 23f853bdf..1f2fbe160 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py
@@ -1,5 +1,6 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, List, Dict, Union, Optional
+import inspect
+from typing import Any, Callable, List, Dict, Union, Optional
 from enum import Enum
 
@@ -35,6 +36,25 @@ def __post_init__(self):
         if 'name' not in self.toolChoice.tool:
             raise ValueError("'name' is required when 'tool' is specified in toolChoice")
 
+    @staticmethod
+    def from_functions(functions: List[Callable]) -> 'ToolConfig':
+        tools = []
+        for func in functions:
+            tool_spec = ToolSpec(
+                name=func.__name__,
+                description=func.__doc__,
+                inputSchema={
+                    "json": {
+                        "type": "object",
+                        "properties": {param: {"type": "string"} for param in inspect.signature(func).parameters},
+                        "required": list(inspect.signature(func).parameters.keys()),
+                    }
+                },
+            )
+            tools.append(Tool(toolSpec=tool_spec))
+
+        return ToolConfig(tools=tools)
+
     @classmethod
     def from_dict(cls, config: Dict) -> 'ToolConfig':
         tools = [Tool(ToolSpec(**tool['toolSpec'])) for tool in config.get('tools', [])]
@@ -51,6 +71,19 @@ def from_dict(cls, config: Dict) -> 'ToolConfig':
 
         return cls(tools=tools, toolChoice=tool_choice)
 
+    def to_dict(self) -> Dict[str, Any]:
+        result = {"tools": [{"toolSpec": asdict(tool.toolSpec)} for tool in self.tools]}
+        if self.toolChoice:
+            tool_choice: Dict[str, Dict[str, Any]] = {}
+            if self.toolChoice.auto:
+                tool_choice["auto"] = self.toolChoice.auto
+            elif self.toolChoice.any:
+                tool_choice["any"] = self.toolChoice.any
+            elif self.toolChoice.tool:
+                tool_choice["tool"] = self.toolChoice.tool
+            result["toolChoice"] = tool_choice
+        return result
+
 
 @dataclass
 class DocumentSource:
     bytes: bytes

From 1d70fa5dd0f5d69b1259f52ab9538c91abefd594 Mon Sep 17 00:00:00 2001
From: Florian Rumiel
Date: Wed, 21 Aug 2024 23:29:43 -0400
Subject: [PATCH 09/35] tool use streaming

---
 .../converse/converse_generator.py           | 47 ++++++++-----
 .../amazon_bedrock/converse/test_gen.py      |  5 +-
 .../amazon_bedrock/converse/utils.py         | 70 ++++++++++++++++++-
 3 files changed, 102 insertions(+), 20 deletions(-)

diff --git
a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index b6fd5709e..bde536046 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -12,7 +12,7 @@ AmazonBedrockInferenceError, ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session -from utils import ConverseMessage +from utils import ConverseMessage, ConverseStreamingChunk, get_stream_message logger = logging.getLogger(__name__) @@ -78,10 +78,10 @@ def __init__( aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, # used for pipeline setup inference_config: Optional[Dict[str, Any]] = None, tool_config: Optional[Dict[str, Any]] = None, + streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = None, ): """ Initializes the `AmazonBedrockConverseGenerator` with the provided parameters. The parameters are passed to the @@ -140,6 +140,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: self.inference_config = inference_config self.tool_config = tool_config + self.streaming_callback = streaming_callback except Exception as exception: msg = ( @@ -148,8 +149,6 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception - self.streaming_callback = streaming_callback - @component.output_types( message=ConverseMessage, usage=Dict[str, Any], @@ -160,7 +159,7 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: def run( self, messages: List[ConverseMessage], - streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = None, inference_config: Dict[str, Any] = {}, tool_config: Optional[Dict[str, Any]] = None, ): @@ -192,23 +191,35 @@ def run( msg = f"The model {self.model} requires a list of ConverseMessage objects as a prompt." 
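# Aside: a sketch of a callback compatible with the new streaming_callback
# signature. chunk.content holds the content accumulated for the current block
# so far (see get_stream_message in utils.py), not just the latest delta:
#     def print_chunk(chunk: ConverseStreamingChunk) -> None:
#         if isinstance(chunk.content, str):
#             print(chunk.content, end="", flush=True)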
             raise ValueError(msg)
 
+        request_kwargs = {
+            "modelId": self.model,
+            "inferenceConfig": inference_config,
+            "messages": [message.to_dict() for message in messages],
+        }
+
+        tool_config = tool_config or self.tool_config
+        if tool_config is not None:
+            request_kwargs["toolConfig"] = tool_config
+
         try:
             if streaming_callback:
-                raise NotImplementedError
+                response = self.client.converse_stream(**request_kwargs)
+                response_stream = response.get("stream")
+                message, metadata = get_stream_message(stream=response_stream, streaming_callback=streaming_callback)
+                return {
+                    "message": get_stream_message(
+                        stream=response_stream,
+                        streaming_callback=streaming_callback,
+                    ),
+                    "usage": metadata.get("usage"),
+                    "metrics": metadata.get("metrics"),
+                    "guardrail_trace": metadata.get("trace"),
+                    "stop_reason": metadata.get("stopReason"),
+                }
-            # toolConfig is optionnal, so we can add it only if tool_config is not None
-            request_kwargs = {
-                "modelId": self.model,
-                "inferenceConfig": inference_config,
-                "messages": [message.to_dict() for message in messages],
+            else:
+                # toolConfig is optional, but the Converse API call will fail if it is empty, so we add it only if tool_config is not None
                 response = self.client.converse(**request_kwargs)
-
+
                 output = response.get("output")
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py
index f4a48a017..8ee3b0cf8 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py
@@ -19,7 +19,10 @@ def get_current_time(timezone: str) -> str:
 
 
 def main():
-    g = AmazonBedrockConverseGenerator(model="anthropic.claude-3-haiku-20240307-v1:0")
+    g = AmazonBedrockConverseGenerator(
+        model="anthropic.claude-3-haiku-20240307-v1:0",
+        streaming_callback=print,
+    )
 
     # Create ToolConfig from functions
     tool_config = ToolConfig.from_functions([get_current_weather, get_current_time])
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py
index 1f2fbe160..6bcb7d61e 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py
@@ -1,7 +1,9 @@
 from dataclasses import asdict, dataclass, field
 import inspect
-from typing import Any, Callable, List, Dict, Union, Optional
+import json
+from typing import Any, Callable, List, Dict, Tuple, Union, Optional
 from enum import Enum
+from botocore.eventstream import EventStream
 
 
 @dataclass
@@ -231,3 +233,69 @@ def to_dict(self):
         "role": self.role.value,
         "content": self.content.to_dict(),
     }
+
+
+@dataclass
+class ConverseStreamingChunk:
+    content: Union[str, ToolUseBlock]
+    metadata: Dict[str, Any]
+    index: int = 0
+
+
+def get_stream_message(
+    stream: EventStream,
+    streaming_callback: Callable[[ConverseStreamingChunk], None],
+) -> Tuple[ConverseMessage, Dict[str, Any]]:
+
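# Annotation (an assumption based on the handling below and the documented
# ConverseStream response syntax): events arrive roughly in the order
#     messageStart -> [contentBlockStart? -> contentBlockDelta* -> contentBlockStop]* -> messageStop -> metadata
# Text blocks stream as {"text": ...} deltas; tool-use blocks stream their
# arguments as partial JSON strings in {"toolUse": {"input": ...}} that only
# parse once the whole block has arrived.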
streaming_chunks: List[ConverseStreamingChunk] = [] + tool_use_dict = {} + str_message = "" + latest_metadata = {} + content_index = 0 # used to keep track of the current str/tool use alternance + current_tool_use_str = "" + for event in stream: + if "contentBlockStart" in event: + if len(current_tool_use_str) > 0 and content_index != event["contentBlockStart"].get("contentBlockIndex"): + tool_use_dict["input"] = current_tool_use_str + current_tool_use_str = "" + + start = event["contentBlockStart"].get("start") + content_index = event["contentBlockStart"].get("contentBlockIndex", content_index) + + if start: + tool_use_dict["toolUseId"] = start["toolUse"]["toolUseId"] + tool_use_dict["name"] = start["toolUse"]["name"] + + if "contentBlockDelta" in event: + delta = event["contentBlockDelta"].get("delta") + + if "text" in delta: + str_message += delta["text"] + if "toolUse" in delta: + current_tool_use_str += delta["toolUse"]["input"] + + if "contentBlockStop" in event: + content_index += 1 # start a new str/tool use alternation + + if "messageStop" in event: + stop_reason = event["messageStop"].get("stopReason") + latest_metadata["stopReason"] = stop_reason + + latest_metadata.update(event.get("metadata", {})) + + block_content = ToolUseBlock(**tool_use_dict) if len(tool_use_dict) == 3 else str_message + + streaming_chunk = ConverseStreamingChunk( + content=block_content, + metadata=event.get("metadata", {}), + index=content_index, + ) + + streaming_callback(streaming_chunk) + streaming_chunks.append(streaming_chunk) + return ( + ConverseMessage( + role=ConverseRole.ASSISTANT, + content=ContentBlock([block_content]), + ), + latest_metadata, + ) From cc373c2f4662ce71b4cdb42afb2427358bc16ffd Mon Sep 17 00:00:00 2001 From: Florian Rumiel Date: Fri, 23 Aug 2024 21:34:55 -0400 Subject: [PATCH 10/35] attempt for tool use 1# --- .../amazon_bedrock/converse/utils.py | 48 ++++++++++++------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 6bcb7d61e..71674422f 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -246,35 +246,50 @@ def get_stream_message( stream: EventStream, streaming_callback: Callable[[ConverseStreamingChunk], None], ) -> Tuple[ConverseMessage, Dict[str, Any]]: - streaming_chunks: List[ConverseStreamingChunk] = [] tool_use_dict = {} str_message = "" latest_metadata = {} - content_index = 0 # used to keep track of the current str/tool use alternance - current_tool_use_str = "" + current_index = 0 # used to keep track of the current str/tool use alternance + + new_tool_use = False + current_tool_use_id = "" + current_tool_use_input = "" + + streaming_blocks = [] + for event in stream: if "contentBlockStart" in event: - if len(current_tool_use_str) > 0 and content_index != event["contentBlockStart"].get("contentBlockIndex"): - tool_use_dict["input"] = current_tool_use_str - current_tool_use_str = "" - - start = event["contentBlockStart"].get("start") - content_index = event["contentBlockStart"].get("contentBlockIndex", content_index) - - if start: - tool_use_dict["toolUseId"] = start["toolUse"]["toolUseId"] - tool_use_dict["name"] = start["toolUse"]["name"] + 
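# Aside: the shape of the event unpacked here, as an illustrative example
# (field values are made up; the structure follows the accesses below):
#     {"contentBlockStart": {"contentBlockIndex": 1,
#                            "start": {"toolUse": {"toolUseId": "tooluse_example", "name": "get_current_weather"}}}}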
content_index = event["contentBlockStart"].get("contentBlockIndex") + + # get the start of the tool use + start_of_tool_use = event["contentBlockStart"].get("start") + if event["contentBlockStart"].get("start"): + tool_use_id = start_of_tool_use["toolUse"]["toolUseId"] + tool_use_name = start_of_tool_use["toolUse"]["name"] + + if tool_use_id != current_tool_use_id: + new_tool_use = True + current_tool_use_id = tool_use_id + current_tool_use_input = "" + tool_use_dict = { + "toolUseId": tool_use_id, + "name": tool_use_name, + "input": json.loads(current_tool_use_input), + } + else: + new_tool_use = False if "contentBlockDelta" in event: delta = event["contentBlockDelta"].get("delta") if "text" in delta: str_message += delta["text"] - if "toolUse" in delta: + if "toolUse" in delta and not new_tool_use: current_tool_use_str += delta["toolUse"]["input"] if "contentBlockStop" in event: - content_index += 1 # start a new str/tool use alternation + new_tool_use = False + content_index += 1 # start a new str/tool use alternation if "messageStop" in event: stop_reason = event["messageStop"].get("stopReason") @@ -291,11 +306,10 @@ def get_stream_message( ) streaming_callback(streaming_chunk) - streaming_chunks.append(streaming_chunk) return ( ConverseMessage( role=ConverseRole.ASSISTANT, - content=ContentBlock([block_content]), + content=ContentBlock(streaming_blocks), ), latest_metadata, ) From 231bbd6e789b4b42a41f089ac418adad09807a28 Mon Sep 17 00:00:00 2001 From: Florian Rumiel Date: Fri, 23 Aug 2024 22:49:47 -0400 Subject: [PATCH 11/35] tool use working --- .../converse/converse_generator.py | 5 +- .../amazon_bedrock/converse/utils.py | 121 +++++++++++------- 2 files changed, 79 insertions(+), 47 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index bde536046..3261f0677 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -207,10 +207,7 @@ def run( response_stream = response.get("stream") message, metadata = get_stream_message(stream=response_stream, streaming_callback=streaming_callback) return { - "message": get_stream_message( - stream=response_stream, - streaming_callback=streaming_callback, - ), + "message": message, "usage": metadata.get("usage"), "metrics": metadata.get("metrics"), "guardrail_trace": metadata.get("trace"), diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 71674422f..66a3b8e62 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -1,7 +1,7 @@ from dataclasses import asdict, dataclass, field import inspect import json -from typing import Any, Callable, List, Dict, Tuple, Union, Optional +from typing import Any, Callable, List, Dict, Sequence, Tuple, Union, Optional from enum import Enum from botocore.eventstream import EventStream @@ -135,7 +135,7 @@ class ToolResultBlock: 
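# Aside: a representative instance once a tool call has fully streamed
# (the id format here is illustrative only):
#     ToolUseBlock(toolUseId="tooluse_example123", name="get_current_weather",
#                  input={"location": "Paris", "unit": "celsius"})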
class ToolUseBlock: toolUseId: str name: str - input: Dict + input: Dict[str, Any] class ConverseRole(str, Enum): @@ -160,6 +160,10 @@ def __post_init__(self): "GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, or ToolUseBlock" ) + @staticmethod + def from_assistant(content: Sequence[Union[str, ToolUseBlock]]) -> 'ContentBlock': + return ContentBlock(content=list(content)) + def to_dict(self): res = [] for item in self.content: @@ -233,83 +237,114 @@ def to_dict(self): "role": self.role.value, "content": self.content.to_dict(), } - + @dataclass class ConverseStreamingChunk: content: Union[str, ToolUseBlock] metadata: Dict[str, Any] index: int = 0 + type: str = "" def get_stream_message( stream: EventStream, streaming_callback: Callable[[ConverseStreamingChunk], None], ) -> Tuple[ConverseMessage, Dict[str, Any]]: - tool_use_dict = {} - str_message = "" - latest_metadata = {} - current_index = 0 # used to keep track of the current str/tool use alternance - new_tool_use = False - current_tool_use_id = "" - current_tool_use_input = "" + current_block: Union[str, ToolUseBlock] = "" + current_tool_use_input_str: str = "" + latest_metadata: Dict[str, Any] = {} + event_type: str + current_index: int = 0 # Start with 0 as the first content block seems to be always text + # which never starts with a content block start for some reason... - streaming_blocks = [] + streamed_contents: List[Union[str, ToolUseBlock]] = [] for event in stream: if "contentBlockStart" in event: - content_index = event["contentBlockStart"].get("contentBlockIndex") + event_type = "contentBlockStart" + new_index = event["contentBlockStart"].get("contentBlockIndex", current_index + 1) + + # If index changed, we're starting a new block + if new_index != current_index: + if current_block: + streamed_contents.append(current_block) + current_index = new_index + current_block = "" - # get the start of the tool use start_of_tool_use = event["contentBlockStart"].get("start") - if event["contentBlockStart"].get("start"): - tool_use_id = start_of_tool_use["toolUse"]["toolUseId"] - tool_use_name = start_of_tool_use["toolUse"]["name"] - - if tool_use_id != current_tool_use_id: - new_tool_use = True - current_tool_use_id = tool_use_id - current_tool_use_input = "" - tool_use_dict = { - "toolUseId": tool_use_id, - "name": tool_use_name, - "input": json.loads(current_tool_use_input), - } - else: - new_tool_use = False + if start_of_tool_use: + current_block = ToolUseBlock( + toolUseId=start_of_tool_use["toolUse"]["toolUseId"], + name=start_of_tool_use["toolUse"]["name"], + input={}, + ) if "contentBlockDelta" in event: - delta = event["contentBlockDelta"].get("delta") + event_type = "contentBlockDelta" + delta = event["contentBlockDelta"].get("delta", {}) if "text" in delta: - str_message += delta["text"] - if "toolUse" in delta and not new_tool_use: - current_tool_use_str += delta["toolUse"]["input"] + if isinstance(current_block, str): + current_block += delta["text"] + else: + # If we get text when we expected a tool use, start a new string block + streamed_contents.append(current_block) + current_block = delta["text"] + current_index += 1 + + if "toolUse" in delta: + if isinstance(current_block, ToolUseBlock): + tool_use_input_delta = delta["toolUse"].get("input") + current_tool_use_input_str += tool_use_input_delta + else: + # If we get a tool use when we expected text, start a new ToolUseBlock + streamed_contents.append(current_block) + current_block = ToolUseBlock( + toolUseId=delta["toolUse"]["toolUseId"], + 
name=delta["toolUse"]["name"], + input=(json.loads(current_tool_use_input_str)), + ) + current_index += 1 if "contentBlockStop" in event: - new_tool_use = False - content_index += 1 # start a new str/tool use alternation + event_type = "contentBlockStop" + if isinstance(current_block, ToolUseBlock): + current_block.input = json.loads(current_tool_use_input_str) + current_tool_use_input_str = "" + streamed_contents.append(current_block) + current_block = "" + current_index += 1 if "messageStop" in event: - stop_reason = event["messageStop"].get("stopReason") - latest_metadata["stopReason"] = stop_reason + event_type = "messageStop" + latest_metadata["stopReason"] = event["messageStop"].get("stopReason") + + if "metadata" in event: + event_type = "metadata" + + if "messageStart" in event: + event_type = "messageStart" latest_metadata.update(event.get("metadata", {})) - block_content = ToolUseBlock(**tool_use_dict) if len(tool_use_dict) == 3 else str_message - streaming_chunk = ConverseStreamingChunk( - content=block_content, - metadata=event.get("metadata", {}), - index=content_index, + content=current_block, + metadata=latest_metadata, + index=current_index, + type=event_type, ) - streaming_callback(streaming_chunk) + + # Add any remaining content + if current_block: + streamed_contents.append(current_block) + return ( ConverseMessage( role=ConverseRole.ASSISTANT, - content=ContentBlock(streaming_blocks), + content=ContentBlock.from_assistant(streamed_contents), ), latest_metadata, ) From cef6ff641fd1cde9784ccb700cb86b9fd4d7fd97 Mon Sep 17 00:00:00 2001 From: Florian Rumiel Date: Fri, 23 Aug 2024 23:22:01 -0400 Subject: [PATCH 12/35] refactor get_stream_response --- .../converse/converse_generator.py | 11 +- .../amazon_bedrock/converse/utils.py | 153 +++++++++--------- 2 files changed, 83 insertions(+), 81 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 3261f0677..4441e88e1 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -78,7 +78,6 @@ def __init__( aws_session_token: Optional[Secret] = Secret.from_env_var(["AWS_SESSION_TOKEN"], strict=False), # noqa: B008 aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 - # used for pipeline setup inference_config: Optional[Dict[str, Any]] = None, tool_config: Optional[Dict[str, Any]] = None, streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = None, @@ -100,19 +99,13 @@ def __init__( :param aws_session_token: AWS session token. :param aws_region_name: AWS region name. Make sure the region you set supports Amazon Bedrock. :param aws_profile_name: AWS profile name. - :param generation_kwargs: Keyword arguments sent to the model. These - parameters are specific to a model. You can find them in the [converse documentation](). - :param stop_words: A list of stop words that stop the model from generating more text - when encountered. 
You can provide them using - this parameter or using the model's `generation_kwargs` under a model's specific key for stop words. - For example, you can provide - stop words for Anthropic Claude in the `stop_sequences` key. :param streaming_callback: A callback function called when a new token is received from the stream. By default, the model is not set up for streaming. To enable streaming, set this parameter to a callback function that handles the streaming chunks. The callback function receives a [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches the streaming mode on. - :param truncate: Whether to truncate the prompt messages or not. + :param inference_config: A dictionary containing the inference configuration. The default value is None. + :param tool_config: A dictionary containing the tool configuration. The default value is None. """ if not model: msg = "'model' cannot be None or empty string" diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 66a3b8e62..67e0187d3 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -1,6 +1,7 @@ from dataclasses import asdict, dataclass, field import inspect import json +import logging from typing import Any, Callable, List, Dict, Sequence, Tuple, Union, Optional from enum import Enum from botocore.eventstream import EventStream @@ -237,7 +238,7 @@ def to_dict(self): "role": self.role.value, "content": self.content.to_dict(), } - + @dataclass class ConverseStreamingChunk: @@ -247,95 +248,103 @@ class ConverseStreamingChunk: type: str = "" +@dataclass +class StreamEvent: + type: str + data: Dict[str, Any] + + +def parse_event(event: Dict[str, Any]) -> StreamEvent: + for key in ['contentBlockStart', 'contentBlockDelta', 'contentBlockStop', 'messageStop', 'messageStart']: + if key in event: + return StreamEvent(type=key, data=event[key]) + return StreamEvent(type='metadata', data=event.get('metadata', {})) + + +def handle_content_block_start(event: StreamEvent, current_index: int) -> Tuple[int, Union[str, ToolUseBlock]]: + new_index = event.data.get('contentBlockIndex', current_index + 1) + start_of_tool_use = event.data.get('start') + if start_of_tool_use: + return new_index, ToolUseBlock( + toolUseId=start_of_tool_use['toolUse']['toolUseId'], + name=start_of_tool_use['toolUse']['name'], + input={}, + ) + return new_index, "" + + +def handle_content_block_delta( + event: StreamEvent, current_block: Union[str, ToolUseBlock], current_tool_use_input_str: str +) -> Tuple[Union[str, ToolUseBlock], str]: + delta = event.data.get('delta', {}) + if 'text' in delta: + if isinstance(current_block, str): + return current_block + delta['text'], current_tool_use_input_str + else: + return delta['text'], current_tool_use_input_str + if 'toolUse' in delta: + if isinstance(current_block, ToolUseBlock): + return current_block, current_tool_use_input_str + delta['toolUse'].get('input', '') + else: + return ToolUseBlock( + toolUseId=delta['toolUse']['toolUseId'], + name=delta['toolUse']['name'], + input={}, + ), delta['toolUse'].get('input', '') + return current_block, current_tool_use_input_str + + def get_stream_message( stream: EventStream, 
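    # editor note: `stream` is the botocore EventStream taken from the
    # client.converse_stream() response; each event is a dict keyed by its type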
streaming_callback: Callable[[ConverseStreamingChunk], None], ) -> Tuple[ConverseMessage, Dict[str, Any]]: - current_block: Union[str, ToolUseBlock] = "" current_tool_use_input_str: str = "" latest_metadata: Dict[str, Any] = {} - event_type: str - current_index: int = 0 # Start with 0 as the first content block seems to be always text - # which never starts with a content block start for some reason... - + current_index: int = 0 streamed_contents: List[Union[str, ToolUseBlock]] = [] - for event in stream: - if "contentBlockStart" in event: - event_type = "contentBlockStart" - new_index = event["contentBlockStart"].get("contentBlockIndex", current_index + 1) + try: + for raw_event in stream: + event = parse_event(raw_event) - # If index changed, we're starting a new block - if new_index != current_index: + if event.type == 'contentBlockStart': if current_block: streamed_contents.append(current_block) - current_index = new_index - current_block = "" - - start_of_tool_use = event["contentBlockStart"].get("start") - if start_of_tool_use: - current_block = ToolUseBlock( - toolUseId=start_of_tool_use["toolUse"]["toolUseId"], - name=start_of_tool_use["toolUse"]["name"], - input={}, - ) + current_index, current_block = handle_content_block_start(event, current_index) - if "contentBlockDelta" in event: - event_type = "contentBlockDelta" - delta = event["contentBlockDelta"].get("delta", {}) - - if "text" in delta: - if isinstance(current_block, str): - current_block += delta["text"] - else: - # If we get text when we expected a tool use, start a new string block + elif event.type == 'contentBlockDelta': + new_block, new_input_str = handle_content_block_delta(event, current_block, current_tool_use_input_str) + if new_block != current_block: streamed_contents.append(current_block) - current_block = delta["text"] current_index += 1 + current_block, current_tool_use_input_str = new_block, new_input_str - if "toolUse" in delta: + elif event.type == 'contentBlockStop': if isinstance(current_block, ToolUseBlock): - tool_use_input_delta = delta["toolUse"].get("input") - current_tool_use_input_str += tool_use_input_delta - else: - # If we get a tool use when we expected text, start a new ToolUseBlock - streamed_contents.append(current_block) - current_block = ToolUseBlock( - toolUseId=delta["toolUse"]["toolUseId"], - name=delta["toolUse"]["name"], - input=(json.loads(current_tool_use_input_str)), - ) - current_index += 1 + current_block.input = json.loads(current_tool_use_input_str) + current_tool_use_input_str = "" + streamed_contents.append(current_block) + current_block = "" + current_index += 1 - if "contentBlockStop" in event: - event_type = "contentBlockStop" - if isinstance(current_block, ToolUseBlock): - current_block.input = json.loads(current_tool_use_input_str) - current_tool_use_input_str = "" - streamed_contents.append(current_block) - current_block = "" - current_index += 1 - - if "messageStop" in event: - event_type = "messageStop" - latest_metadata["stopReason"] = event["messageStop"].get("stopReason") - - if "metadata" in event: - event_type = "metadata" - - if "messageStart" in event: - event_type = "messageStart" - - latest_metadata.update(event.get("metadata", {})) - - streaming_chunk = ConverseStreamingChunk( - content=current_block, - metadata=latest_metadata, - index=current_index, - type=event_type, - ) - streaming_callback(streaming_chunk) + elif event.type == 'messageStop': + latest_metadata["stopReason"] = event.data.get("stopReason") + + latest_metadata.update(event.data if 
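                # editor note: metadata events carry the usage, metrics, and
                # guardrail trace that run() later surfaces in its output dict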
event.type == 'metadata' else {}) + + streaming_chunk = ConverseStreamingChunk( + content=current_block, + metadata=latest_metadata, + index=current_index, + type=event.type, + ) + streaming_callback(streaming_chunk) + + except Exception as e: + # Log the error and re-raise + logging.error(f"Error processing stream: {str(e)}") + raise # Add any remaining content if current_block: From f01af07f6587a65afeef389fca24f8bfd6075f55 Mon Sep 17 00:00:00 2001 From: Florian Rumiel Date: Fri, 23 Aug 2024 23:37:33 -0400 Subject: [PATCH 13/35] fix streaming accumulative response --- .../converse/converse_generator.py | 4 +-- .../amazon_bedrock/converse/utils.py | 30 +++++++++++++++---- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 4441e88e1..c95c1fc61 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -196,6 +196,7 @@ def run( try: if streaming_callback: + response = self.client.converse_stream(**request_kwargs) response_stream = response.get("stream") message, metadata = get_stream_message(stream=response_stream, streaming_callback=streaming_callback) @@ -207,7 +208,6 @@ def run( "stop_reason": metadata.get("stopReason"), } else: - # toolConfig is optionnal but the converse api will fail if it is empty, so we can add it only if tool_config is not None response = self.client.converse(**request_kwargs) output = response.get("output") @@ -225,7 +225,7 @@ def run( "stop_reason": response.get("stopReason"), } except ClientError as exception: - msg = f"Could not inference Amazon Bedrock model {self.model} due: {exception}" + msg = f"Could not run inference on Amazon Bedrock model {self.model} due: {exception}" raise AmazonBedrockInferenceError(msg) from exception def to_dict(self) -> Dict[str, Any]: diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 67e0187d3..486c37592 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -310,13 +310,24 @@ def get_stream_message( if event.type == 'contentBlockStart': if current_block: - streamed_contents.append(current_block) + if isinstance(current_block, str) and streamed_contents and isinstance(streamed_contents[-1], str): + streamed_contents[-1] += current_block + else: + streamed_contents.append(current_block) current_index, current_block = handle_content_block_start(event, current_index) elif event.type == 'contentBlockDelta': new_block, new_input_str = handle_content_block_delta(event, current_block, current_tool_use_input_str) - if new_block != current_block: - streamed_contents.append(current_block) + if isinstance(new_block, ToolUseBlock) and new_block != current_block: + if current_block: + if ( + isinstance(current_block, str) + and streamed_contents + and isinstance(streamed_contents[-1], str) + ): + 
streamed_contents[-1] += current_block + else: + streamed_contents.append(current_block) current_index += 1 current_block, current_tool_use_input_str = new_block, new_input_str @@ -324,7 +335,12 @@ def get_stream_message( if isinstance(current_block, ToolUseBlock): current_block.input = json.loads(current_tool_use_input_str) current_tool_use_input_str = "" - streamed_contents.append(current_block) + streamed_contents.append(current_block) + elif isinstance(current_block, str): + if streamed_contents and isinstance(streamed_contents[-1], str): + streamed_contents[-1] += current_block + else: + streamed_contents.append(current_block) current_block = "" current_index += 1 @@ -342,13 +358,15 @@ def get_stream_message( streaming_callback(streaming_chunk) except Exception as e: - # Log the error and re-raise logging.error(f"Error processing stream: {str(e)}") raise # Add any remaining content if current_block: - streamed_contents.append(current_block) + if isinstance(current_block, str) and streamed_contents and isinstance(streamed_contents[-1], str): + streamed_contents[-1] += current_block + else: + streamed_contents.append(current_block) return ( ConverseMessage( From c2c8fea0e8b5883bb2be1e932336cfb70aba14c5 Mon Sep 17 00:00:00 2001 From: Florian Rumiel Date: Sat, 24 Aug 2024 00:12:24 -0400 Subject: [PATCH 14/35] validate input params against model capabilities --- .../amazon_bedrock/converse/capabilities.py | 102 +++++++++++++++ .../converse/converse_generator.py | 123 ++++++++++-------- 2 files changed, 168 insertions(+), 57 deletions(-) create mode 100644 integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py new file mode 100644 index 000000000..a6962eb1a --- /dev/null +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py @@ -0,0 +1,102 @@ +from enum import Enum, auto + + +class ModelCapability(Enum): + CONVERSE = auto() + CONVERSE_STREAM = auto() + SYSTEM_PROMPTS = auto() + DOCUMENT_CHAT = auto() + VISION = auto() + TOOL_USE = auto() + STREAMING_TOOL_USE = auto() + GUARDRAILS = auto() + +# https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html +MODEL_CAPABILITIES = { + "ai21.j2-.*-instruct": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + }, + "ai21.j2-.*-text": {ModelCapability.CONVERSE, ModelCapability.GUARDRAILS}, + "amazon.titan-.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "amazon.titan-text-express-v1": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.GUARDRAILS, + }, + "anthropic.claude-2.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "anthropic.claude-3.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.VISION, + ModelCapability.TOOL_USE, + ModelCapability.STREAMING_TOOL_USE, + ModelCapability.GUARDRAILS, + }, + "cohere.command-text.*": {ModelCapability.CONVERSE, 
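        # editor note: no CONVERSE_STREAM entry here, so run() silently falls
        # back to the blocking converse() call for this model family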
ModelCapability.DOCUMENT_CHAT, ModelCapability.GUARDRAILS}, + "cohere.command-light.*": {ModelCapability.CONVERSE, ModelCapability.GUARDRAILS}, + "cohere.command-r.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.TOOL_USE, + }, + "meta.llama2.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "meta.llama3.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "meta.llama3-1.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.TOOL_USE, + ModelCapability.GUARDRAILS, + }, + "mistral.mistral-.*-instruct": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "mistral.mistral-large.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.TOOL_USE, + ModelCapability.GUARDRAILS, + }, + "mistral.mistral-small.*": { + ModelCapability.CONVERSE, + ModelCapability.CONVERSE_STREAM, + ModelCapability.SYSTEM_PROMPTS, + ModelCapability.TOOL_USE, + ModelCapability.GUARDRAILS, + }, +} diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index c95c1fc61..bf586c44c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -1,6 +1,7 @@ +from enum import Enum, auto import logging import re -from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict @@ -12,7 +13,11 @@ AmazonBedrockInferenceError, ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session -from utils import ConverseMessage, ConverseStreamingChunk, get_stream_message +from capabilities import ( + ModelCapability, + MODEL_CAPABILITIES, +) +from utils import ConverseMessage, ConverseRole, ConverseStreamingChunk, ImageBlock, ToolConfig, get_stream_message logger = logging.getLogger(__name__) @@ -56,17 +61,6 @@ class AmazonBedrockConverseGenerator: """ # according to the list provided in the toolConfig arg: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax - SUPPORTED_TOOL_MODEL_PATTERNS: ClassVar[List[str]] = [ - r"anthropic.claude-3.*", - r"cohere.command-r.*", - r"mistral.mistral-large.*", - ] - - UNSUPPORTED_CHAT_MODEL_PATTERNS: ClassVar[List[str]] = [ - r"cohere.command-text.*", - r"cohere.command-light.*", - r"ai21.j2.*", - ] def __init__( self, @@ -81,6 +75,7 @@ def __init__( inference_config: Optional[Dict[str, Any]] = None, tool_config: Optional[Dict[str, Any]] = None, streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = 
None, + system_prompt: Optional[List[Dict[str, Any]]] = None, ): """ Initializes the `AmazonBedrockConverseGenerator` with the provided parameters. The parameters are passed to the @@ -134,6 +129,8 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: self.inference_config = inference_config self.tool_config = tool_config self.streaming_callback = streaming_callback + self.system_prompt = system_prompt + self.model_capabilities = self._get_model_capabilities(model) except Exception as exception: msg = ( @@ -142,6 +139,12 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: ) raise AmazonBedrockConfigurationError(msg) from exception + def _get_model_capabilities(self, model: str) -> Set[ModelCapability]: + for pattern, capabilities in MODEL_CAPABILITIES.items(): + if re.match(pattern, model): + return capabilities + raise ValueError(f"Unsupported model: {model}") + @component.output_types( message=ConverseMessage, usage=Dict[str, Any], @@ -154,78 +157,84 @@ def run( messages: List[ConverseMessage], streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = None, inference_config: Dict[str, Any] = {}, - tool_config: Optional[Dict[str, Any]] = None, + tool_config: Optional[ToolConfig] = None, + system_prompt: Optional[List[Dict[str, Any]]] = None, ): - """ - Generates a list of `ChatMessage` response to the given messages using the Amazon Bedrock LLM. - - :param messages: The messages to generate a response to. - :param streaming_callback: - A callback function that is called when a new token is received from the stream. - :param generation_kwargs: Additional generation keyword arguments passed to the model. - :returns: A dictionary with the following keys: - - `replies`: The generated List of `ChatMessage` objects. - """ streaming_callback = streaming_callback or self.streaming_callback + system_prompt = system_prompt or self.system_prompt - # warn and only keep last message if model does not support chat - if re.match("|".join(self.UNSUPPORTED_CHAT_MODEL_PATTERNS), self.model) and len(messages) > 1: - logging.warning( - f"The model {self.model} does not support chat. Only the last message " "will be taken into account." - ) - messages = messages[-1:] - - # check if the prompt is a list of ConverseMessage objects if not ( isinstance(messages, list) and len(messages) > 0 and all(isinstance(message, ConverseMessage) for message in messages) ): - msg = f"The model {self.model} requires a list of ConverseMessage objects as a prompt." + msg = f"The model {self.model} requires a list of ConverseMessage objects as input." raise ValueError(msg) + # Check and filter messages based on model capabilities + if ModelCapability.SYSTEM_PROMPTS not in self.model_capabilities and system_prompt: + logger.warning( + f"The model {self.model} does not support system prompts. The provided system_prompt will be ignored." + ) + system_prompt = None + + if ModelCapability.VISION not in self.model_capabilities: + for msg in messages: + msg.content.content = [item for item in msg.content.content if not isinstance(item, ImageBlock)] + if any(isinstance(item, ImageBlock) for msg in messages for item in msg.content.content): + logger.warning(f"The model {self.model} does not support vision. Image content has been removed.") + + if ModelCapability.DOCUMENT_CHAT not in self.model_capabilities: + logger.warning( + f"The model {self.model} does not support document chat. This feature will not be available." 
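                # editor note: unlike the vision branch above, document content
                # is only warned about here and is not stripped from `messages`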
+ ) + + if ModelCapability.TOOL_USE not in self.model_capabilities and tool_config: + logger.warning(f"The model {self.model} does not support tools. The provided tool_config will be ignored.") + tool_config = None + + if ModelCapability.STREAMING_TOOL_USE not in self.model_capabilities and streaming_callback and tool_config: + logger.warning( + f"The model {self.model} does not support streaming tool use. Streaming will be disabled for tool calls." + ) + request_kwargs = { "modelId": self.model, "inferenceConfig": inference_config, "messages": [message.to_dict() for message in messages], } - tool_config = tool_config or self.tool_config - if tool_config is not None: + if tool_config: request_kwargs["toolConfig"] = tool_config try: - if streaming_callback: - + if streaming_callback and ModelCapability.CONVERSE_STREAM in self.model_capabilities: response = self.client.converse_stream(**request_kwargs) response_stream = response.get("stream") message, metadata = get_stream_message(stream=response_stream, streaming_callback=streaming_callback) - return { - "message": message, - "usage": metadata.get("usage"), - "metrics": metadata.get("metrics"), - "guardrail_trace": metadata.get("trace"), - "stop_reason": metadata.get("stopReason"), - } else: response = self.client.converse(**request_kwargs) - output = response.get("output") if output is None: - raise KeyError + raise KeyError("Response does not contain 'output'") message = output.get("message") if message is None: - raise KeyError - - return { - "message": ConverseMessage.from_dict(message), - "usage": response.get("usage"), - "metrics": response.get("metrics"), - "guardrail_trace": response.get("trace"), - "stop_reason": response.get("stopReason"), - } + raise KeyError("Response 'output' does not contain 'message'") + message = ConverseMessage.from_dict(message) + metadata = response + + return { + "message": message, + "usage": metadata.get("usage"), + "metrics": metadata.get("metrics"), + "guardrail_trace": ( + metadata.get("trace") if ModelCapability.GUARDRAILS in self.model_capabilities else None + ), + "stop_reason": metadata.get("stopReason"), + } + except ClientError as exception: - msg = f"Could not run inference on Amazon Bedrock model {self.model} due: {exception}" + msg = f"Could not run inference on Amazon Bedrock model {self.model} due to: {exception}" raise AmazonBedrockInferenceError(msg) from exception def to_dict(self) -> Dict[str, Any]: @@ -255,7 +264,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": :param data: Dictionary to deserialize from. :returns: - Deserialized component. + Deserialized component. 
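            A round-trip sketch (assuming the default env-var secrets resolve):
            `restored = AmazonBedrockConverseGenerator.from_dict(generator.to_dict())`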
""" init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") From 0a4e8c02a9222e388020d07bfa2552d88ddd9c42 Mon Sep 17 00:00:00 2001 From: Florian Rumiel Date: Sat, 24 Aug 2024 09:04:10 -0400 Subject: [PATCH 15/35] format check in document block --- .../components/generators/amazon_bedrock/converse/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 486c37592..cb81cfd95 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -2,7 +2,7 @@ import inspect import json import logging -from typing import Any, Callable, List, Dict, Sequence, Tuple, Union, Optional +from typing import Any, Callable, List, Dict, Literal, Sequence, Tuple, Union, Optional from enum import Enum from botocore.eventstream import EventStream @@ -95,9 +95,10 @@ class DocumentSource: @dataclass class DocumentBlock: - format: str + SUPPORTED_FORMATS = Literal['pdf', 'csv', 'doc', 'docx', 'xls', 'xlsx', 'html', 'txt', 'md'] + format: SUPPORTED_FORMATS name: str - source: DocumentSource + source: bytes @dataclass From 8dac6f540f48a680f4f014aed3c9586862fd6acb Mon Sep 17 00:00:00 2001 From: Florian Rumiel Date: Mon, 26 Aug 2024 08:31:50 -0400 Subject: [PATCH 16/35] move files to proper folder --- .../converse_generator_example.py} | 15 ++- .../generators/amazon_bedrock/__init__.py | 3 +- .../converse/converse_generator.py | 4 +- .../tests/test_converse_generator.py | 93 +++++++++++++++++++ 4 files changed, 102 insertions(+), 13 deletions(-) rename integrations/amazon_bedrock/{src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py => examples/converse_generator_example.py} (82%) create mode 100644 integrations/amazon_bedrock/tests/test_converse_generator.py diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py b/integrations/amazon_bedrock/examples/converse_generator_example.py similarity index 82% rename from integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py rename to integrations/amazon_bedrock/examples/converse_generator_example.py index 8ee3b0cf8..61922b1a1 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/test_gen.py +++ b/integrations/amazon_bedrock/examples/converse_generator_example.py @@ -1,7 +1,4 @@ -import sys -from typing import Dict -from haystack.dataclasses import ChatMessage -from converse_generator import AmazonBedrockConverseGenerator +from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockConverseGenerator from utils import ConverseMessage, ToolConfig from haystack.core.pipeline import Pipeline @@ -19,8 +16,8 @@ def get_current_time(timezone: str) -> str: def main(): - g = AmazonBedrockConverseGenerator( - model="anthropic.claude-3-haiku-20240307-v1:0", + generator = AmazonBedrockConverseGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", streaming_callback=print, ) @@ -33,11 +30,11 @@ def main(): print("Tool Config:") print(tool_config_dict) - p = Pipeline() - 
p.add_component("generator", g) + pipeline = Pipeline() + pipeline.add_component("generator", generator) print("\nRunning pipeline with tools:") - result = p.run( + result = pipeline.run( data={ "generator": { "inference_config": { diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 2d33beb42..5ef74f997 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -3,5 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AmazonBedrockChatGenerator from .generator import AmazonBedrockGenerator +from .converse.converse_generator import AmazonBedrockConverseGenerator -__all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator"] +__all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator", "AmazonBedrockConverseGenerator"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index bf586c44c..24ef38f75 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -1,11 +1,9 @@ -from enum import Enum, auto import logging import re -from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type +from typing import Any, Callable, Dict, List, Optional, Set from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict -from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from haystack_integrations.common.amazon_bedrock.errors import ( diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py new file mode 100644 index 000000000..57512e997 --- /dev/null +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -0,0 +1,93 @@ +import pytest +from unittest.mock import Mock, patch +from haystack.dataclasses import ChatMessage +from converse_generator import AmazonBedrockConverseGenerator, AmazonBedrockConfigurationError +from utils import ConverseMessage, ToolConfig, ConverseRole, ContentBlock +from capabilities import ModelCapability, MODEL_CAPABILITIES + + +@pytest.fixture +def generator(): + model = "anthropic.claude-3-haiku-20240307-v1:0" + return AmazonBedrockConverseGenerator(model=model, streaming_callback=print) + + +def test_init(generator): + assert generator.model == "anthropic.claude-3-haiku-20240307-v1:0" + assert generator.streaming_callback == print + assert generator.client is not None + + +def test_get_model_capabilities(generator): + capabilities = generator._get_model_capabilities(generator.model) + expected_capabilities = MODEL_CAPABILITIES["anthropic.claude-3.*"] + assert capabilities == expected_capabilities + + +def test_get_model_capabilities_unsupported_model(generator): + with 
pytest.raises(ValueError): + generator._get_model_capabilities("unsupported_model") + + +@patch('converse_generator.get_aws_session') +def test_init_aws_error(mock_get_aws_session): + mock_get_aws_session.side_effect = Exception("AWS Error") + with pytest.raises(AmazonBedrockConfigurationError): + AmazonBedrockConverseGenerator(model="anthropic.claude-3-haiku-20240307-v1:0") + + +@patch.object(AmazonBedrockConverseGenerator, 'client') +def test_run_streaming(mock_client, generator): + mock_stream = Mock() + mock_client.converse_stream.return_value = {"stream": mock_stream} + + messages = [ConverseMessage.from_user(["Hello"])] + streaming_callback = Mock() + + result = generator.run(messages, streaming_callback=streaming_callback) + + mock_client.converse_stream.assert_called_once() + assert "message" in result + assert "usage" in result + assert "metrics" in result + assert "guardrail_trace" in result + assert "stop_reason" in result + + +@patch.object(AmazonBedrockConverseGenerator, 'client') +def test_run_non_streaming(mock_client, generator): + mock_response = { + "output": {"message": {"role": "assistant", "content": [{"text": "Hello, how can I help you?"}]}}, + "usage": {"inputTokens": 10, "outputTokens": 20}, + "metrics": {"firstByteLatency": 0.5}, + } + mock_client.converse.return_value = mock_response + + messages = [ConverseMessage.from_user(["Hello"])] + + result = generator.run(messages) + + mock_client.converse.assert_called_once() + assert isinstance(result["message"], ConverseMessage) + assert result["usage"] == mock_response["usage"] + assert result["metrics"] == mock_response["metrics"] + + +def test_run_invalid_messages(generator): + with pytest.raises(ValueError): + generator.run(["invalid message"]) + + +def test_to_dict(generator): + serialized = generator.to_dict() + assert "model" in serialized + assert serialized["model"] == generator.model + + +def test_from_dict(): + data = { + "init_parameters": {"model": "anthropic.claude-3-haiku-20240307-v1:0", "streaming_callback": "builtins.print"} + } + deserialized = AmazonBedrockConverseGenerator.from_dict(data) + assert deserialized.model == "anthropic.claude-3-haiku-20240307-v1:0" + assert deserialized.streaming_callback == print From 98efefc6168e48b929b9bad499d3f363e3807570 Mon Sep 17 00:00:00 2001 From: FloRul Date: Mon, 26 Aug 2024 22:21:05 -0400 Subject: [PATCH 17/35] fix imports --- .../examples/converse_generator_example.py | 9 +- .../generators/amazon_bedrock/__init__.py | 2 +- .../amazon_bedrock/converse/capabilities.py | 1 + .../converse/converse_generator.py | 11 +- .../amazon_bedrock/converse/utils.py | 105 +++++++++--------- .../tests/test_converse_generator.py | 17 +-- 6 files changed, 76 insertions(+), 69 deletions(-) diff --git a/integrations/amazon_bedrock/examples/converse_generator_example.py b/integrations/amazon_bedrock/examples/converse_generator_example.py index 61922b1a1..8fa8d2181 100644 --- a/integrations/amazon_bedrock/examples/converse_generator_example.py +++ b/integrations/amazon_bedrock/examples/converse_generator_example.py @@ -1,6 +1,9 @@ -from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockConverseGenerator -from utils import ConverseMessage, ToolConfig -from haystack.core.pipeline import Pipeline + + + +from haystack import Pipeline +from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockConverseGenerator +from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ConverseMessage,ToolConfig def 
get_current_weather(location: str, unit: str = "celsius") -> str: diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 5ef74f997..3ae2af9df 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AmazonBedrockChatGenerator -from .generator import AmazonBedrockGenerator from .converse.converse_generator import AmazonBedrockConverseGenerator +from .generator import AmazonBedrockGenerator __all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator", "AmazonBedrockConverseGenerator"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py index a6962eb1a..4c69d94c5 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py @@ -11,6 +11,7 @@ class ModelCapability(Enum): STREAMING_TOOL_USE = auto() GUARDRAILS = auto() + # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html MODEL_CAPABILITIES = { "ai21.j2-.*-instruct": { diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 24ef38f75..eaa0e6b77 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -3,19 +3,20 @@ from typing import Any, Callable, Dict, List, Optional, Set from botocore.exceptions import ClientError +from .capabilities import ( + MODEL_CAPABILITIES, + ModelCapability, +) from haystack import component, default_from_dict, default_to_dict from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from .utils import ConverseMessage, ConverseStreamingChunk, ImageBlock, ToolConfig, get_stream_message + from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, AmazonBedrockInferenceError, ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session -from capabilities import ( - ModelCapability, - MODEL_CAPABILITIES, -) -from utils import ConverseMessage, ConverseRole, ConverseStreamingChunk, ImageBlock, ToolConfig, get_stream_message logger = logging.getLogger(__name__) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index cb81cfd95..808d0d37a 100644 --- 
a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -1,9 +1,10 @@ -from dataclasses import asdict, dataclass, field -import inspect +import inspect import json import logging -from typing import Any, Callable, List, Dict, Literal, Sequence, Tuple, Union, Optional +from dataclasses import asdict, dataclass, field from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union + from botocore.eventstream import EventStream @@ -36,11 +37,11 @@ def __post_init__(self): raise ValueError("Only one of 'auto', 'any', or 'tool' can be set in toolChoice") if self.toolChoice and self.toolChoice.tool: - if 'name' not in self.toolChoice.tool: + if "name" not in self.toolChoice.tool: raise ValueError("'name' is required when 'tool' is specified in toolChoice") @staticmethod - def from_functions(functions: List[Callable]) -> 'ToolConfig': + def from_functions(functions: List[Callable]) -> "ToolConfig": tools = [] for func in functions: tool_spec = ToolSpec( @@ -59,18 +60,18 @@ def from_functions(functions: List[Callable]) -> 'ToolConfig': return ToolConfig(tools=tools) @classmethod - def from_dict(cls, config: Dict) -> 'ToolConfig': - tools = [Tool(ToolSpec(**tool['toolSpec'])) for tool in config.get('tools', [])] + def from_dict(cls, config: Dict) -> "ToolConfig": + tools = [Tool(ToolSpec(**tool["toolSpec"])) for tool in config.get("tools", [])] tool_choice = None - if 'toolChoice' in config: - tc = config['toolChoice'] - if 'auto' in tc: - tool_choice = ToolChoice(auto=tc['auto']) - elif 'any' in tc: - tool_choice = ToolChoice(any=tc['any']) - elif 'tool' in tc: - tool_choice = ToolChoice(tool={'name': tc['tool']['name']}) + if "toolChoice" in config: + tc = config["toolChoice"] + if "auto" in tc: + tool_choice = ToolChoice(auto=tc["auto"]) + elif "any" in tc: + tool_choice = ToolChoice(any=tc["any"]) + elif "tool" in tc: + tool_choice = ToolChoice(tool={"name": tc["tool"]["name"]}) return cls(tools=tools, toolChoice=tool_choice) @@ -95,7 +96,7 @@ class DocumentSource: @dataclass class DocumentBlock: - SUPPORTED_FORMATS = Literal['pdf', 'csv', 'doc', 'docx', 'xls', 'xlsx', 'html', 'txt', 'md'] + SUPPORTED_FORMATS = Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] format: SUPPORTED_FORMATS name: str source: bytes @@ -163,7 +164,7 @@ def __post_init__(self): ) @staticmethod - def from_assistant(content: Sequence[Union[str, ToolUseBlock]]) -> 'ContentBlock': + def from_assistant(content: Sequence[Union[str, ToolUseBlock]]) -> "ContentBlock": return ContentBlock(content=list(content)) def to_dict(self): @@ -213,22 +214,22 @@ def from_user( @staticmethod def from_dict(data: Dict[str, Any]) -> "ConverseMessage": - role = ConverseRole(data['role']) + role = ConverseRole(data["role"]) content_blocks = [] - for item in data['content']: - if 'text' in item: - content_blocks.append(item['text']) - elif 'image' in item: - content_blocks.append(ImageBlock(**item['image'])) - elif 'document' in item: - content_blocks.append(DocumentBlock(**item['document'])) - elif 'toolUse' in item: - content_blocks.append(ToolUseBlock(**item['toolUse'])) - elif 'toolResult' in item: - content_blocks.append(ToolResultBlock(**item['toolResult'])) - elif 'guardContent' in item: - content_blocks.append(GuardrailConverseContentBlock(**item['guardContent'])) + for item in data["content"]: + if 
"text" in item: + content_blocks.append(item["text"]) + elif "image" in item: + content_blocks.append(ImageBlock(**item["image"])) + elif "document" in item: + content_blocks.append(DocumentBlock(**item["document"])) + elif "toolUse" in item: + content_blocks.append(ToolUseBlock(**item["toolUse"])) + elif "toolResult" in item: + content_blocks.append(ToolResultBlock(**item["toolResult"])) + elif "guardContent" in item: + content_blocks.append(GuardrailConverseContentBlock(**item["guardContent"])) else: raise ValueError(f"Unknown content type in message: {item}") @@ -256,19 +257,19 @@ class StreamEvent: def parse_event(event: Dict[str, Any]) -> StreamEvent: - for key in ['contentBlockStart', 'contentBlockDelta', 'contentBlockStop', 'messageStop', 'messageStart']: + for key in ["contentBlockStart", "contentBlockDelta", "contentBlockStop", "messageStop", "messageStart"]: if key in event: return StreamEvent(type=key, data=event[key]) - return StreamEvent(type='metadata', data=event.get('metadata', {})) + return StreamEvent(type="metadata", data=event.get("metadata", {})) def handle_content_block_start(event: StreamEvent, current_index: int) -> Tuple[int, Union[str, ToolUseBlock]]: - new_index = event.data.get('contentBlockIndex', current_index + 1) - start_of_tool_use = event.data.get('start') + new_index = event.data.get("contentBlockIndex", current_index + 1) + start_of_tool_use = event.data.get("start") if start_of_tool_use: return new_index, ToolUseBlock( - toolUseId=start_of_tool_use['toolUse']['toolUseId'], - name=start_of_tool_use['toolUse']['name'], + toolUseId=start_of_tool_use["toolUse"]["toolUseId"], + name=start_of_tool_use["toolUse"]["name"], input={}, ) return new_index, "" @@ -277,21 +278,21 @@ def handle_content_block_start(event: StreamEvent, current_index: int) -> Tuple[ def handle_content_block_delta( event: StreamEvent, current_block: Union[str, ToolUseBlock], current_tool_use_input_str: str ) -> Tuple[Union[str, ToolUseBlock], str]: - delta = event.data.get('delta', {}) - if 'text' in delta: + delta = event.data.get("delta", {}) + if "text" in delta: if isinstance(current_block, str): - return current_block + delta['text'], current_tool_use_input_str + return current_block + delta["text"], current_tool_use_input_str else: - return delta['text'], current_tool_use_input_str - if 'toolUse' in delta: + return delta["text"], current_tool_use_input_str + if "toolUse" in delta: if isinstance(current_block, ToolUseBlock): - return current_block, current_tool_use_input_str + delta['toolUse'].get('input', '') + return current_block, current_tool_use_input_str + delta["toolUse"].get("input", "") else: return ToolUseBlock( - toolUseId=delta['toolUse']['toolUseId'], - name=delta['toolUse']['name'], + toolUseId=delta["toolUse"]["toolUseId"], + name=delta["toolUse"]["name"], input={}, - ), delta['toolUse'].get('input', '') + ), delta["toolUse"].get("input", "") return current_block, current_tool_use_input_str @@ -309,7 +310,7 @@ def get_stream_message( for raw_event in stream: event = parse_event(raw_event) - if event.type == 'contentBlockStart': + if event.type == "contentBlockStart": if current_block: if isinstance(current_block, str) and streamed_contents and isinstance(streamed_contents[-1], str): streamed_contents[-1] += current_block @@ -317,7 +318,7 @@ def get_stream_message( streamed_contents.append(current_block) current_index, current_block = handle_content_block_start(event, current_index) - elif event.type == 'contentBlockDelta': + elif event.type == "contentBlockDelta": 
new_block, new_input_str = handle_content_block_delta(event, current_block, current_tool_use_input_str) if isinstance(new_block, ToolUseBlock) and new_block != current_block: if current_block: @@ -332,7 +333,7 @@ def get_stream_message( current_index += 1 current_block, current_tool_use_input_str = new_block, new_input_str - elif event.type == 'contentBlockStop': + elif event.type == "contentBlockStop": if isinstance(current_block, ToolUseBlock): current_block.input = json.loads(current_tool_use_input_str) current_tool_use_input_str = "" @@ -345,10 +346,10 @@ def get_stream_message( current_block = "" current_index += 1 - elif event.type == 'messageStop': + elif event.type == "messageStop": latest_metadata["stopReason"] = event.data.get("stopReason") - latest_metadata.update(event.data if event.type == 'metadata' else {}) + latest_metadata.update(event.data if event.type == "metadata" else {}) streaming_chunk = ConverseStreamingChunk( content=current_block, @@ -359,7 +360,7 @@ def get_stream_message( streaming_callback(streaming_chunk) except Exception as e: - logging.error(f"Error processing stream: {str(e)}") + logging.error(f"Error processing stream: {e!s}") raise # Add any remaining content diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 57512e997..a3799185f 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -1,9 +1,10 @@ -import pytest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch + +import pytest +from capabilities import MODEL_CAPABILITIES, ModelCapability +from converse_generator import AmazonBedrockConfigurationError, AmazonBedrockConverseGenerator from haystack.dataclasses import ChatMessage -from converse_generator import AmazonBedrockConverseGenerator, AmazonBedrockConfigurationError -from utils import ConverseMessage, ToolConfig, ConverseRole, ContentBlock -from capabilities import ModelCapability, MODEL_CAPABILITIES +from utils import ContentBlock, ConverseMessage, ConverseRole, ToolConfig @pytest.fixture @@ -29,14 +30,14 @@ def test_get_model_capabilities_unsupported_model(generator): generator._get_model_capabilities("unsupported_model") -@patch('converse_generator.get_aws_session') +@patch("converse_generator.get_aws_session") def test_init_aws_error(mock_get_aws_session): mock_get_aws_session.side_effect = Exception("AWS Error") with pytest.raises(AmazonBedrockConfigurationError): AmazonBedrockConverseGenerator(model="anthropic.claude-3-haiku-20240307-v1:0") -@patch.object(AmazonBedrockConverseGenerator, 'client') +@patch.object(AmazonBedrockConverseGenerator, "client") def test_run_streaming(mock_client, generator): mock_stream = Mock() mock_client.converse_stream.return_value = {"stream": mock_stream} @@ -54,7 +55,7 @@ def test_run_streaming(mock_client, generator): assert "stop_reason" in result -@patch.object(AmazonBedrockConverseGenerator, 'client') +@patch.object(AmazonBedrockConverseGenerator, "client") def test_run_non_streaming(mock_client, generator): mock_response = { "output": {"message": {"role": "assistant", "content": [{"text": "Hello, how can I help you?"}]}}, From 21081a12010b5b57b5fe5caca150e12b868ac9b2 Mon Sep 17 00:00:00 2001 From: FloRul Date: Sun, 1 Sep 2024 00:09:38 -0400 Subject: [PATCH 18/35] set up tests and fix module import paths --- integrations/amazon_bedrock/pydoc/config.yml | 1 + .../generators/amazon_bedrock/__init__.py | 
11 +- .../converse/converse_generator.py | 9 +- .../tests/test_converse_generator.py | 188 +++++++++--------- 4 files changed, 117 insertions(+), 92 deletions(-) diff --git a/integrations/amazon_bedrock/pydoc/config.yml b/integrations/amazon_bedrock/pydoc/config.yml index 6cb05d6f3..d84f0bac9 100644 --- a/integrations/amazon_bedrock/pydoc/config.yml +++ b/integrations/amazon_bedrock/pydoc/config.yml @@ -7,6 +7,7 @@ loaders: "haystack_integrations.common.amazon_bedrock.errors", "haystack_integrations.components.generators.amazon_bedrock.handlers", "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator", + "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator", "haystack_integrations.components.embedders.amazon_bedrock.text_embedder", "haystack_integrations.components.embedders.amazon_bedrock.document_embedder", ] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 3ae2af9df..f74e8a7ce 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -4,5 +4,14 @@ from .chat.chat_generator import AmazonBedrockChatGenerator from .converse.converse_generator import AmazonBedrockConverseGenerator from .generator import AmazonBedrockGenerator +from .converse.utils import ConverseMessage, ToolConfig +from .converse.capabilities import MODEL_CAPABILITIES -__all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator", "AmazonBedrockConverseGenerator"] +__all__ = [ + "AmazonBedrockGenerator", + "AmazonBedrockChatGenerator", + "AmazonBedrockConverseGenerator", + "ConverseMessage", + "ToolConfig", + "MODEL_CAPABILITIES", +] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index eaa0e6b77..3704bd81a 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -72,7 +72,7 @@ def __init__( aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False), # noqa: B008 aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 inference_config: Optional[Dict[str, Any]] = None, - tool_config: Optional[Dict[str, Any]] = None, + tool_config: Optional[ToolConfig] = None, streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = None, system_prompt: Optional[List[Dict[str, Any]]] = None, ): @@ -205,6 +205,10 @@ def run( if tool_config: request_kwargs["toolConfig"] = tool_config + if system_prompt: + request_kwargs["system"] = { + "text": system_prompt, + } try: if streaming_callback and ModelCapability.CONVERSE_STREAM in self.model_capabilities: @@ -253,6 +257,9 @@ def to_dict(self) -> Dict[str, Any]: aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None, model=self.model, streaming_callback=callback_name, + system_prompt=self.system_prompt, + inference_config=self.inference_config, + 
tool_config=self.tool_config.to_dict() if self.tool_config else None, ) @classmethod diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index a3799185f..4f07f8373 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -1,94 +1,102 @@ -from unittest.mock import Mock, patch - +from unittest.mock import Mock, patch, MagicMock import pytest -from capabilities import MODEL_CAPABILITIES, ModelCapability -from converse_generator import AmazonBedrockConfigurationError, AmazonBedrockConverseGenerator -from haystack.dataclasses import ChatMessage -from utils import ContentBlock, ConverseMessage, ConverseRole, ToolConfig - - -@pytest.fixture -def generator(): - model = "anthropic.claude-3-haiku-20240307-v1:0" - return AmazonBedrockConverseGenerator(model=model, streaming_callback=print) - - -def test_init(generator): - assert generator.model == "anthropic.claude-3-haiku-20240307-v1:0" - assert generator.streaming_callback == print - assert generator.client is not None - - -def test_get_model_capabilities(generator): - capabilities = generator._get_model_capabilities(generator.model) - expected_capabilities = MODEL_CAPABILITIES["anthropic.claude-3.*"] - assert capabilities == expected_capabilities - - -def test_get_model_capabilities_unsupported_model(generator): - with pytest.raises(ValueError): - generator._get_model_capabilities("unsupported_model") - - -@patch("converse_generator.get_aws_session") -def test_init_aws_error(mock_get_aws_session): - mock_get_aws_session.side_effect = Exception("AWS Error") - with pytest.raises(AmazonBedrockConfigurationError): - AmazonBedrockConverseGenerator(model="anthropic.claude-3-haiku-20240307-v1:0") - - -@patch.object(AmazonBedrockConverseGenerator, "client") -def test_run_streaming(mock_client, generator): - mock_stream = Mock() - mock_client.converse_stream.return_value = {"stream": mock_stream} - messages = [ConverseMessage.from_user(["Hello"])] - streaming_callback = Mock() - - result = generator.run(messages, streaming_callback=streaming_callback) - - mock_client.converse_stream.assert_called_once() - assert "message" in result - assert "usage" in result - assert "metrics" in result - assert "guardrail_trace" in result - assert "stop_reason" in result - - -@patch.object(AmazonBedrockConverseGenerator, "client") -def test_run_non_streaming(mock_client, generator): - mock_response = { - "output": {"message": {"role": "assistant", "content": [{"text": "Hello, how can I help you?"}]}}, - "usage": {"inputTokens": 10, "outputTokens": 20}, - "metrics": {"firstByteLatency": 0.5}, +from haystack_integrations.components.generators.amazon_bedrock import ( + ConverseMessage, + AmazonBedrockConverseGenerator, + MODEL_CAPABILITIES, + ToolConfig, +) +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockConfigurationError, +) + + +def get_current_weather(location: str, unit: str = "celsius") -> str: + """Get the current weather in a given location""" + # This is a mock function, replace with actual API call + return f"The weather in {location} is 22 degrees {unit}." + + +def get_current_time(timezone: str) -> str: + """Get the current time in a given timezone""" + # This is a mock function, replace with actual time lookup + return f"The current time in {timezone} is 14:30." 
+ + +def test_to_dict(mock_boto3_session): + """ + Test that the to_dict method returns the correct dictionary without aws credentials + """ + tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) + + generator = AmazonBedrockConverseGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + inference_config={ + "temperature": 0.1, + "maxTokens": 256, + "topP": 0.1, + "stopSequences": ["\\n"], + }, + tool_config=tool_config, + ) + + expected_dict = { + "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator.AmazonBedrockConverseGenerator", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "inference_config": { + "temperature": 0.1, + "maxTokens": 256, + "topP": 0.1, + "stopSequences": ["\\n"], + }, + "tool_config": tool_config.to_dict(), + "streaming_callback": None, + "system_prompt": None, + }, } - mock_client.converse.return_value = mock_response - - messages = [ConverseMessage.from_user(["Hello"])] - - result = generator.run(messages) - - mock_client.converse.assert_called_once() - assert isinstance(result["message"], ConverseMessage) - assert result["usage"] == mock_response["usage"] - assert result["metrics"] == mock_response["metrics"] - -def test_run_invalid_messages(generator): - with pytest.raises(ValueError): - generator.run(["invalid message"]) - - -def test_to_dict(generator): - serialized = generator.to_dict() - assert "model" in serialized - assert serialized["model"] == generator.model - - -def test_from_dict(): - data = { - "init_parameters": {"model": "anthropic.claude-3-haiku-20240307-v1:0", "streaming_callback": "builtins.print"} - } - deserialized = AmazonBedrockConverseGenerator.from_dict(data) - assert deserialized.model == "anthropic.claude-3-haiku-20240307-v1:0" - assert deserialized.streaming_callback == print + assert generator.to_dict() == expected_dict + + +def test_from_dict(mock_boto3_session): + """ + Test that the from_dict method returns the correct object + """ + tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) + + generator = AmazonBedrockConverseGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator.AmazonBedrockConverseGenerator", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", + "inference_config": { + "temperature": 0.1, + "maxTokens": 256, + "topP": 0.1, + "stopSequences": ["\\n"], + }, + "tool_config": tool_config.to_dict(), + "streaming_callback": None, + "system_prompt": 
None,
+            },
+        }
+    )
+
+    assert generator.inference_config["temperature"] == 0.1
+    assert generator.inference_config["maxTokens"] == 256
+    assert generator.inference_config["topP"] == 0.1
+    assert generator.inference_config["stopSequences"] == ["\\n"]
+    assert generator.tool_config.to_dict() == tool_config.to_dict()
+    assert generator.model == "anthropic.claude-3-5-sonnet-20240620-v1:0"

From 75139dd2d1c79bfffb1227a10fe713001c204da8 Mon Sep 17 00:00:00 2001
From: FloRul
Date: Sun, 1 Sep 2024 22:28:41 -0400
Subject: [PATCH 19/35] more tests

---
 .../converse/converse_generator.py            |  33 ++--
 .../amazon_bedrock/converse/utils.py          |  33 +++-
 .../tests/test_converse_generator.py          | 180 ++++++++++++++++++
 3 files changed, 227 insertions(+), 19 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py
index 3704bd81a..770a46443 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py
@@ -10,7 +10,7 @@
 from haystack import component, default_from_dict, default_to_dict
 from haystack.utils.auth import Secret, deserialize_secrets_inplace
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
-from .utils import ConverseMessage, ConverseStreamingChunk, ImageBlock, ToolConfig, get_stream_message
+from .utils import ContentBlock, ConverseMessage, ConverseStreamingChunk, ImageBlock, ToolConfig, get_stream_message
 
 from haystack_integrations.common.amazon_bedrock.errors import (
     AmazonBedrockConfigurationError,
@@ -111,7 +111,13 @@ def __init__(
         self.aws_region_name = aws_region_name
         self.aws_profile_name = aws_profile_name
 
-        # create the AWS session and client
+        self.inference_config = inference_config
+        self.tool_config = tool_config
+        self.streaming_callback = streaming_callback
+        self.system_prompt = system_prompt
+        self.model_capabilities = self._get_model_capabilities(model)
+
+        # create the AWS session and client
         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
             return secret.resolve_value() if secret else None
 
@@ -124,13 +130,6 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
             aws_profile_name=resolve_secret(aws_profile_name),
         )
         self.client = session.client("bedrock-runtime")
-
-        self.inference_config = inference_config
-        self.tool_config = tool_config
-        self.streaming_callback = streaming_callback
-        self.system_prompt = system_prompt
-        self.model_capabilities = self._get_model_capabilities(model)
-
         except Exception as exception:
             msg = (
                 "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly.
" @@ -161,6 +159,7 @@ def run( ): streaming_callback = streaming_callback or self.streaming_callback system_prompt = system_prompt or self.system_prompt + tool_config = tool_config or self.tool_config if not ( isinstance(messages, list) @@ -178,10 +177,15 @@ def run( system_prompt = None if ModelCapability.VISION not in self.model_capabilities: - for msg in messages: - msg.content.content = [item for item in msg.content.content if not isinstance(item, ImageBlock)] - if any(isinstance(item, ImageBlock) for msg in messages for item in msg.content.content): - logger.warning(f"The model {self.model} does not support vision. Image content has been removed.") + messages = [ + ConverseMessage( + role=msg.role, + content=ContentBlock( + content=[item for item in msg.content.content if not isinstance(item, ImageBlock)] + ), + ) + for msg in messages + ] if ModelCapability.DOCUMENT_CHAT not in self.model_capabilities: logger.warning( @@ -276,6 +280,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + tool_config = data.get("init_parameters", {}).get("tool_config") + if tool_config: + data["init_parameters"]["tool_config"] = ToolConfig.from_dict(tool_config) deserialize_secrets_inplace( data["init_parameters"], [ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 808d0d37a..257af1c9b 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -256,14 +256,14 @@ class StreamEvent: data: Dict[str, Any] -def parse_event(event: Dict[str, Any]) -> StreamEvent: +def _parse_event(event: Dict[str, Any]) -> StreamEvent: for key in ["contentBlockStart", "contentBlockDelta", "contentBlockStop", "messageStop", "messageStart"]: if key in event: return StreamEvent(type=key, data=event[key]) return StreamEvent(type="metadata", data=event.get("metadata", {})) -def handle_content_block_start(event: StreamEvent, current_index: int) -> Tuple[int, Union[str, ToolUseBlock]]: +def _handle_content_block_start(event: StreamEvent, current_index: int) -> Tuple[int, Union[str, ToolUseBlock]]: new_index = event.data.get("contentBlockIndex", current_index + 1) start_of_tool_use = event.data.get("start") if start_of_tool_use: @@ -275,7 +275,7 @@ def handle_content_block_start(event: StreamEvent, current_index: int) -> Tuple[ return new_index, "" -def handle_content_block_delta( +def _handle_content_block_delta( event: StreamEvent, current_block: Union[str, ToolUseBlock], current_tool_use_input_str: str ) -> Tuple[Union[str, ToolUseBlock], str]: delta = event.data.get("delta", {}) @@ -300,6 +300,27 @@ def get_stream_message( stream: EventStream, streaming_callback: Callable[[ConverseStreamingChunk], None], ) -> Tuple[ConverseMessage, Dict[str, Any]]: + """ + Processes a stream of messages and returns a ConverseMessage and the associated metadata. + + The stream is expected to contain the following events: + + - contentBlockStart: Indicates the start of a content block. + - contentBlockDelta: Indicates a change to the content block. 
+ - contentBlockStop: Indicates the end of a content block. + - messageStop: Indicates the end of a message. + - metadata: Indicates metadata about the message. + + The function processes each event in the stream and returns a ConverseMessage and the associated metadata. + The ConverseMessage will contain the content of the message, and the metadata will contain the stop reason and any other metadata from the stream. + + The function will also call the streaming_callback function with a ConverseStreamingChunk for each event in the stream. + The ConverseStreamingChunk will contain the content and metadata from the event. + + :param stream: The stream of messages to process. + :param streaming_callback: The callback function to call with each ConverseStreamingChunk. + :return: A tuple containing the ConverseMessage and the associated metadata. + """ current_block: Union[str, ToolUseBlock] = "" current_tool_use_input_str: str = "" latest_metadata: Dict[str, Any] = {} @@ -308,7 +329,7 @@ def get_stream_message( try: for raw_event in stream: - event = parse_event(raw_event) + event = _parse_event(raw_event) if event.type == "contentBlockStart": if current_block: @@ -316,10 +337,10 @@ def get_stream_message( streamed_contents[-1] += current_block else: streamed_contents.append(current_block) - current_index, current_block = handle_content_block_start(event, current_index) + current_index, current_block = _handle_content_block_start(event, current_index) elif event.type == "contentBlockDelta": - new_block, new_input_str = handle_content_block_delta(event, current_block, current_tool_use_input_str) + new_block, new_input_str = _handle_content_block_delta(event, current_block, current_tool_use_input_str) if isinstance(new_block, ToolUseBlock) and new_block != current_block: if current_block: if ( diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 4f07f8373..dcda8a90c 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -1,5 +1,6 @@ from unittest.mock import Mock, patch, MagicMock import pytest +from botocore.exceptions import ClientError from haystack_integrations.components.generators.amazon_bedrock import ( ConverseMessage, @@ -9,6 +10,14 @@ ) from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, +) +from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ( + ConverseRole, + ImageBlock, + ImageSource, + ToolResultBlock, + ToolUseBlock, ) @@ -100,3 +109,174 @@ def test_from_dict(mock_boto3_session): assert generator.inference_config["stopSequences"] == ["\\n"] assert generator.tool_config.to_dict() == tool_config.to_dict() assert generator.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" + + +def test_default_constructor(mock_boto3_session, set_env_variables): + """ + Test that the default constructor sets the correct values + """ + + layer = AmazonBedrockConverseGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + ) + + assert layer.model == "anthropic.claude-3-5-sonnet-20240620-v1:0" + assert layer.inference_config is None + assert layer.tool_config is None + assert layer.streaming_callback is None + assert layer.system_prompt is None + + # assert mocked boto3 client called exactly once + mock_boto3_session.assert_called_once() + + # assert mocked boto3 client was called with the correct 
parameters + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + +def test_constructor_with_empty_model(): + """ + Test that the constructor raises an error when the model is empty + """ + with pytest.raises(ValueError, match="cannot be None or empty string"): + AmazonBedrockConverseGenerator(model="") + + +def test_get_model_capabilities(): + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") + assert generator.model_capabilities == MODEL_CAPABILITIES["anthropic.claude-3.*"] + + generator = AmazonBedrockConverseGenerator(model="ai21.j2-ultra-instruct-v1") + assert generator.model_capabilities == MODEL_CAPABILITIES["ai21.j2-.*-instruct"] + + with pytest.raises(ValueError, match="Unsupported model"): + AmazonBedrockConverseGenerator(model="unsupported.model-v1") + + +@patch("boto3.Session") +def test_run_with_different_message_types(mock_session): + mock_client = Mock() + mock_session.return_value.client.return_value = mock_client + mock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "Hello, how can I help you?"}]}}, + "usage": {"inputTokens": 10, "outputTokens": 20}, + "metrics": {"timeToFirstToken": 0.5}, + } + + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") + + messages = [ + ConverseMessage.from_user(["What's the weather like?"]), + ConverseMessage.from_user([ImageBlock(format="png", source=ImageSource(bytes=b"fake_image_data"))]), + ] + + result = generator.run(messages) + + assert result["message"].role == ConverseRole.ASSISTANT + assert result["message"].content.content[0] == "Hello, how can I help you?" + assert result["usage"] == {"inputTokens": 10, "outputTokens": 20} + assert result["metrics"] == {"timeToFirstToken": 0.5} + + # Check the actual content sent to the API + mock_client.converse.assert_called_once() + call_args = mock_client.converse.call_args[1] + assert len(call_args["messages"]) == 2 + assert call_args["messages"][0]["content"] == [{"text": "What's the weather like?"}] + print(f"Actual content of second message: {call_args['messages'][1]['content']}") + # Depending on the actual behavior, you might need to adjust the following assertion: + assert call_args["messages"][1]["content"] == [] # or whatever the actual behavior is + + +from botocore.stub import Stubber + + +from unittest.mock import Mock, patch +import pytest +from botocore.exceptions import ClientError + + +def test_streaming(): + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") + + mock_stream = Mock() + mock_stream.__iter__ = Mock( + return_value=iter( + [ + {'contentBlockStart': {'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': 'Hello'}}}, + {'contentBlockDelta': {'delta': {'text': ', how can I help you?'}}}, + {'contentBlockStop': {}}, + {'messageStop': {'stopReason': 'endOfResponse'}}, + ] + ) + ) + + generator.client.converse_stream = Mock(return_value={'stream': mock_stream}) + + chunks = [] + result = generator.run([ConverseMessage.from_user(["Hi"])], streaming_callback=lambda chunk: chunks.append(chunk)) + + assert len(chunks) == 5 + assert chunks[0].type == 'contentBlockStart' + assert chunks[1].content == 'Hello' + assert chunks[2].content == ', how can I help you?' 
+ assert chunks[3].type == 'contentBlockStop' + assert chunks[4].type == 'messageStop' + + assert result['message'].content.content[0] == 'Hello, how can I help you?' + assert result['stop_reason'] == 'endOfResponse' + + +def test_client_error_handling(): + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") + generator.client.converse = Mock( + side_effect=ClientError( + error_response={"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, + operation_name="Converse", + ) + ) + + with pytest.raises(AmazonBedrockInferenceError, match="Could not run inference on Amazon Bedrock model"): + generator.run([ConverseMessage.from_user(["Hi"])]) + + +def test_tool_usage(): + tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) + generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", tool_config=tool_config) + + mock_response = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "get_current_weather", + "input": {"location": "London", "unit": "celsius"}, + } + }, + { + "toolResult": { + "toolUseId": "123", + "content": [{"text": "The weather in London is 22 degrees celsius."}], + } + }, + {"text": "Based on the weather information, it's a nice day in London."}, + ], + } + } + } + generator.client.converse = Mock(return_value=mock_response) + + result = generator.run([ConverseMessage.from_user(["What's the weather in London?"])]) + + assert len(result["message"].content.content) == 3 + assert isinstance(result["message"].content.content[0], ToolUseBlock) + assert isinstance(result["message"].content.content[1], ToolResultBlock) + assert result["message"].content.content[2] == "Based on the weather information, it's a nice day in London." 
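The tool-use assertions above stop once the ToolUseBlocks come back from the model. As a rough
sketch of how a caller could close the loop, dispatching each returned block to the local
functions registered through ToolConfig.from_functions, the following illustrates the idea;
the dispatch table, model id, and prompt are illustrative assumptions, not part of this patch
series:

from haystack_integrations.components.generators.amazon_bedrock import (
    AmazonBedrockConverseGenerator,
    ConverseMessage,
    ToolConfig,
)
from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ToolUseBlock


def get_current_weather(location: str, unit: str = "celsius") -> str:
    return f"The weather in {location} is 22 degrees {unit}."


def get_current_time(timezone: str) -> str:
    return f"The current time in {timezone} is 14:30."


# Hypothetical dispatch table; keys must match the function names advertised
# to the model via ToolConfig.from_functions.
TOOLS = {"get_current_weather": get_current_weather, "get_current_time": get_current_time}

generator = AmazonBedrockConverseGenerator(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    tool_config=ToolConfig.from_functions(list(TOOLS.values())),
)
result = generator.run([ConverseMessage.from_user(["What's the weather in Paris?"])])
for block in result["message"].content.content:
    # ToolUseBlock.input is an already-parsed dict, e.g. {"location": "Paris", "unit": "celsius"}.
    if isinstance(block, ToolUseBlock):
        print(block.name, "->", TOOLS[block.name](**block.input))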
From 8b9d0c193771149545bcba7f1fc7f93bd8460f54 Mon Sep 17 00:00:00 2001 From: FloRul Date: Thu, 5 Sep 2024 01:06:49 -0400 Subject: [PATCH 20/35] test consolidation - fix streaming chunk callback --- .../examples/converse_generator_example.py | 6 +- .../converse/converse_generator.py | 2 + .../amazon_bedrock/converse/utils.py | 11 +- .../tests/test_converse_generator.py | 146 ++++++++++++------ 4 files changed, 105 insertions(+), 60 deletions(-) diff --git a/integrations/amazon_bedrock/examples/converse_generator_example.py b/integrations/amazon_bedrock/examples/converse_generator_example.py index 8fa8d2181..b1a506bc1 100644 --- a/integrations/amazon_bedrock/examples/converse_generator_example.py +++ b/integrations/amazon_bedrock/examples/converse_generator_example.py @@ -1,6 +1,3 @@ - - - from haystack import Pipeline from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockConverseGenerator from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ConverseMessage,ToolConfig @@ -21,7 +18,7 @@ def get_current_time(timezone: str) -> str: def main(): generator = AmazonBedrockConverseGenerator( model="anthropic.claude-3-5-sonnet-20240620-v1:0", - streaming_callback=print, + # streaming_callback=print, ) # Create ToolConfig from functions @@ -53,7 +50,6 @@ def main(): }, }, ) - print("\nPipeline Result:") print(result) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 770a46443..ef978c5c3 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -222,6 +222,8 @@ def run( else: response = self.client.converse(**request_kwargs) output = response.get("output") + # TODO: Delete + print(output) if output is None: raise KeyError("Response does not contain 'output'") message = output.get("message") diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 257af1c9b..22be5c29c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -298,7 +298,7 @@ def _handle_content_block_delta( def get_stream_message( stream: EventStream, - streaming_callback: Callable[[ConverseStreamingChunk], None], + streaming_callback: Callable[[StreamEvent], None], ) -> Tuple[ConverseMessage, Dict[str, Any]]: """ Processes a stream of messages and returns a ConverseMessage and the associated metadata. 
@@ -371,14 +371,7 @@ def get_stream_message(
                 latest_metadata["stopReason"] = event.data.get("stopReason")
 
         latest_metadata.update(event.data if event.type == "metadata" else {})
-
-        streaming_chunk = ConverseStreamingChunk(
-            content=current_block,
-            metadata=latest_metadata,
-            index=current_index,
-            type=event.type,
-        )
-        streaming_callback(streaming_chunk)
+        streaming_callback(event)
 
     except Exception as e:
         logging.error(f"Error processing stream: {e!s}")
diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py
index dcda8a90c..3b106a817 100644
--- a/integrations/amazon_bedrock/tests/test_converse_generator.py
+++ b/integrations/amazon_bedrock/tests/test_converse_generator.py
@@ -1,4 +1,5 @@
-from unittest.mock import Mock, patch, MagicMock
+import json
+from unittest.mock import Mock, patch, MagicMock
 import pytest
 from botocore.exceptions import ClientError
 
@@ -16,6 +17,7 @@
     ConverseRole,
     ImageBlock,
     ImageSource,
+    StreamEvent,
     ToolResultBlock,
     ToolUseBlock,
 )
@@ -187,7 +189,7 @@ def test_run_with_different_message_types(mock_session):
     call_args = mock_client.converse.call_args[1]
     assert len(call_args["messages"]) == 2
     assert call_args["messages"][0]["content"] == [{"text": "What's the weather like?"}]
-    print(f"Actual content of second message: {call_args['messages'][1]['content']}")
+    print(f'Actual content of second message: {call_args["messages"][1]["content"]}')
     # Depending on the actual behavior, you might need to adjust the following assertion:
     assert call_args["messages"][1]["content"] == []  # or whatever the actual behavior is
 
@@ -203,33 +205,96 @@ def test_run_with_different_message_types(mock_session):
 
 def test_streaming():
     generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0")
 
+    mocked_events = [
+        {'messageStart': {'role': 'assistant'}},
+        {'contentBlockDelta': {'delta': {'text': 'To'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' answer'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' your questions'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ','}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': " I'll"}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' need to'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' use'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' two'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' different functions'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ':'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' one'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' to check'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' the weather'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' in Paris and another'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' to get the current'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' time in New York'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': '.'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' Let'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' me'}, 'contentBlockIndex': 0}},
+        {'contentBlockDelta': {'delta': {'text': ' fetch'}, 'contentBlockIndex': 0}},
+        
{'contentBlockDelta': {'delta': {'text': ' that'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': ' information for'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': ' you.'}, 'contentBlockIndex': 0}}, + {'contentBlockStop': {'contentBlockIndex': 0}}, + { + 'contentBlockStart': { + 'start': {'toolUse': {'toolUseId': 'tooluse_5Uu9EPSjQxiSsmc5Ex5MJg', 'name': 'get_current_weather'}}, + 'contentBlockIndex': 1, + } + }, + {'contentBlockDelta': {'delta': {'toolUse': {'input': ''}}, 'contentBlockIndex': 1}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': '{"loc'}}, 'contentBlockIndex': 1}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': 'ation":'}}, 'contentBlockIndex': 1}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': ' "Paris"'}}, 'contentBlockIndex': 1}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': ', "u'}}, 'contentBlockIndex': 1}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': 'nit": "ce'}}, 'contentBlockIndex': 1}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': 'lsius'}}, 'contentBlockIndex': 1}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': '"}'}}, 'contentBlockIndex': 1}}, + {'contentBlockStop': {'contentBlockIndex': 1}}, + { + 'contentBlockStart': { + 'start': {'toolUse': {'toolUseId': 'tooluse_cbK-e15KTFqZHtwpBJ0kzg', 'name': 'get_current_time'}}, + 'contentBlockIndex': 2, + } + }, + {'contentBlockDelta': {'delta': {'toolUse': {'input': ''}}, 'contentBlockIndex': 2}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': '{"timezon'}}, 'contentBlockIndex': 2}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': 'e"'}}, 'contentBlockIndex': 2}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': ': "A'}}, 'contentBlockIndex': 2}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': 'meric'}}, 'contentBlockIndex': 2}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': 'a/New'}}, 'contentBlockIndex': 2}}, + {'contentBlockDelta': {'delta': {'toolUse': {'input': '_York"}'}}, 'contentBlockIndex': 2}}, + {'contentBlockStop': {'contentBlockIndex': 2}}, + {'messageStop': {'stopReason': 'tool_use'}}, + { + 'metadata': { + 'usage': {'inputTokens': 446, 'outputTokens': 118, 'totalTokens': 564}, + 'metrics': {'latencyMs': 3930}, + } + }, + ] + mock_stream = Mock() - mock_stream.__iter__ = Mock( - return_value=iter( - [ - {'contentBlockStart': {'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': 'Hello'}}}, - {'contentBlockDelta': {'delta': {'text': ', how can I help you?'}}}, - {'contentBlockStop': {}}, - {'messageStop': {'stopReason': 'endOfResponse'}}, - ] - ) - ) + mock_stream.__iter__ = Mock(return_value=iter(mocked_events)) generator.client.converse_stream = Mock(return_value={'stream': mock_stream}) chunks = [] - result = generator.run([ConverseMessage.from_user(["Hi"])], streaming_callback=lambda chunk: chunks.append(chunk)) + result = generator.run( + [ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"])], + streaming_callback=lambda chunk: chunks.append(chunk), + ) - assert len(chunks) == 5 - assert chunks[0].type == 'contentBlockStart' - assert chunks[1].content == 'Hello' - assert chunks[2].content == ', how can I help you?' 
- assert chunks[3].type == 'contentBlockStop' - assert chunks[4].type == 'messageStop' + assert len(chunks) == len(mocked_events) + assert result["message"].content.content[0] == "To answer your questions, I'll need to use two different functions: one to check the weather in Paris and another to get the current time in New York. Let me fetch that information for you." + assert len(result["message"].content.content) == 3 + assert result["stop_reason"] == 'tool_use' + + assert result["message"].role == ConverseRole.ASSISTANT - assert result['message'].content.content[0] == 'Hello, how can I help you?' - assert result['stop_reason'] == 'endOfResponse' + assert result["usage"] == {'inputTokens': 446, 'outputTokens': 118, 'totalTokens': 564} + assert result["metrics"] == {'latencyMs': 3930} + + assert result["message"].content.content[1].name == "get_current_weather" + assert result["message"].content.content[2].name == "get_current_time" + + assert json.dumps(result["message"].content.content[1].input) == """{"location": "Paris", "unit": "celsius"}""" + assert json.dumps(result["message"].content.content[2].input) == """{"timezone": "America/New_York"}""" def test_client_error_handling(): @@ -250,33 +315,22 @@ def test_tool_usage(): generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", tool_config=tool_config) mock_response = { - "output": { - "message": { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "123", - "name": "get_current_weather", - "input": {"location": "London", "unit": "celsius"}, - } - }, - { - "toolResult": { - "toolUseId": "123", - "content": [{"text": "The weather in London is 22 degrees celsius."}], - } - }, - {"text": "Based on the weather information, it's a nice day in London."}, - ], - } - } + "output": {'message': {'role': 'assistant', 'content': [{'text': "Certainly! I'd be happy to help you with the weather in Paris and the current time in New York. To get this information, I'll need to use two different tools. Let me fetch that data for you."}, {'toolUse': {'toolUseId': 'tooluse_-Tp78_OeSq-1DSsP0B__TA', 'name': 'get_current_weather', 'input': {'location': 'Paris', 'unit': 'celsius'}}}, {'toolUse': {'toolUseId': 'tooluse_gdYvqeiGTme7toWoV4sSKw', 'name': 'get_current_time', 'input': {'timezone': 'America/New_York'},},},],},}, + "stop_reason": "tool_use", } generator.client.converse = Mock(return_value=mock_response) result = generator.run([ConverseMessage.from_user(["What's the weather in London?"])]) assert len(result["message"].content.content) == 3 - assert isinstance(result["message"].content.content[0], ToolUseBlock) - assert isinstance(result["message"].content.content[1], ToolResultBlock) - assert result["message"].content.content[2] == "Based on the weather information, it's a nice day in London." + assert isinstance(result["message"].content.content[0], str) + assert isinstance(result["message"].content.content[1], ToolUseBlock) + assert isinstance(result["message"].content.content[2], ToolUseBlock) + assert result["stop_reason"] == "tool_use" + assert result["message"].role == ConverseRole.ASSISTANT + assert result["message"].content.content[0] == "Certainly! I'd be happy to help you with the weather in Paris and the current time in New York. To get this information, I'll need to use two different tools. Let me fetch that data for you." 
+ assert result["message"].content.content[1].name == "get_current_weather" + assert result["message"].content.content[2].name == "get_current_time" + assert json.dumps(result["message"].content.content[1].input) == """{"location": "Paris", "unit": "celsius"}""" + assert json.dumps(result["message"].content.content[2].input) == """{"timezone": "America/New_York"}""" + From 4beb3be35245f41ee4cd48057dab502780b939ea Mon Sep 17 00:00:00 2001 From: FloRul Date: Fri, 6 Sep 2024 22:53:38 -0400 Subject: [PATCH 21/35] update test, examples and model capabilities --- .../examples/converse_generator_example.py | 5 +- .../generators/amazon_bedrock/__init__.py | 4 +- .../amazon_bedrock/converse/capabilities.py | 23 +- .../converse/converse_generator.py | 39 ++-- .../amazon_bedrock/converse/utils.py | 37 +++- .../tests/test_converse_generator.py | 209 +++++++++++------- 6 files changed, 193 insertions(+), 124 deletions(-) diff --git a/integrations/amazon_bedrock/examples/converse_generator_example.py b/integrations/amazon_bedrock/examples/converse_generator_example.py index b1a506bc1..034829610 100644 --- a/integrations/amazon_bedrock/examples/converse_generator_example.py +++ b/integrations/amazon_bedrock/examples/converse_generator_example.py @@ -1,6 +1,7 @@ from haystack import Pipeline + from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockConverseGenerator -from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ConverseMessage,ToolConfig +from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ConverseMessage, ToolConfig def get_current_weather(location: str, unit: str = "celsius") -> str: @@ -54,5 +55,5 @@ def main(): print(result) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index f74e8a7ce..88139497c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -2,10 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AmazonBedrockChatGenerator +from .converse.capabilities import MODEL_CAPABILITIES from .converse.converse_generator import AmazonBedrockConverseGenerator -from .generator import AmazonBedrockGenerator from .converse.utils import ConverseMessage, ToolConfig -from .converse.capabilities import MODEL_CAPABILITIES +from .generator import AmazonBedrockGenerator __all__ = [ "AmazonBedrockGenerator", diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py index 4c69d94c5..d5dc251a2 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py @@ -14,24 +14,22 @@ class ModelCapability(Enum): # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html MODEL_CAPABILITIES = { - "ai21.j2-.*-instruct": { + "ai21.jamba-instruct-.*": { ModelCapability.CONVERSE, 
ModelCapability.CONVERSE_STREAM, ModelCapability.SYSTEM_PROMPTS, }, - "ai21.j2-.*-text": {ModelCapability.CONVERSE, ModelCapability.GUARDRAILS}, - "amazon.titan-.*": { + "ai21.j2-.*": { ModelCapability.CONVERSE, - ModelCapability.CONVERSE_STREAM, - ModelCapability.DOCUMENT_CHAT, ModelCapability.GUARDRAILS, }, - "amazon.titan-text-express-v1": { + "amazon.titan-text-.*": { ModelCapability.CONVERSE, ModelCapability.CONVERSE_STREAM, + ModelCapability.DOCUMENT_CHAT, ModelCapability.GUARDRAILS, }, - "anthropic.claude-2.*": { + "anthropic.claude-v2.*": { ModelCapability.CONVERSE, ModelCapability.CONVERSE_STREAM, ModelCapability.SYSTEM_PROMPTS, @@ -48,8 +46,15 @@ class ModelCapability(Enum): ModelCapability.STREAMING_TOOL_USE, ModelCapability.GUARDRAILS, }, - "cohere.command-text.*": {ModelCapability.CONVERSE, ModelCapability.DOCUMENT_CHAT, ModelCapability.GUARDRAILS}, - "cohere.command-light.*": {ModelCapability.CONVERSE, ModelCapability.GUARDRAILS}, + "cohere.command-text.*": { + ModelCapability.CONVERSE, + ModelCapability.DOCUMENT_CHAT, + ModelCapability.GUARDRAILS, + }, + "cohere.command-light.*": { + ModelCapability.CONVERSE, + ModelCapability.GUARDRAILS, + }, "cohere.command-r.*": { ModelCapability.CONVERSE, ModelCapability.CONVERSE_STREAM, diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index ef978c5c3..e5b9ae19f 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -3,14 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Set from botocore.exceptions import ClientError -from .capabilities import ( - MODEL_CAPABILITIES, - ModelCapability, -) from haystack import component, default_from_dict, default_to_dict from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.callable_serialization import deserialize_callable, serialize_callable -from .utils import ContentBlock, ConverseMessage, ConverseStreamingChunk, ImageBlock, ToolConfig, get_stream_message from haystack_integrations.common.amazon_bedrock.errors import ( AmazonBedrockConfigurationError, @@ -18,6 +13,12 @@ ) from haystack_integrations.common.amazon_bedrock.utils import get_aws_session +from .capabilities import ( + MODEL_CAPABILITIES, + ModelCapability, +) +from .utils import ContentBlock, ConverseMessage, ConverseStreamingChunk, ImageBlock, ToolConfig, get_stream_message + logger = logging.getLogger(__name__) @@ -140,7 +141,8 @@ def _get_model_capabilities(self, model: str) -> Set[ModelCapability]: for pattern, capabilities in MODEL_CAPABILITIES.items(): if re.match(pattern, model): return capabilities - raise ValueError(f"Unsupported model: {model}") + unsupported_model_error = ValueError(f"Unsupported model: {model}") + raise unsupported_model_error @component.output_types( message=ConverseMessage, @@ -153,7 +155,7 @@ def run( self, messages: List[ConverseMessage], streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = None, - inference_config: Dict[str, Any] = {}, + inference_config: Optional[Dict[str, Any]] = None, tool_config: Optional[ToolConfig] = None, system_prompt: Optional[List[Dict[str, Any]]] = None, ): @@ -198,12 +200,13 @@ def run( if 
ModelCapability.STREAMING_TOOL_USE not in self.model_capabilities and streaming_callback and tool_config: logger.warning( - f"The model {self.model} does not support streaming tool use. Streaming will be disabled for tool calls." + f"The model {self.model} does not support streaming tool use. " + "Streaming will be disabled for tool calls." ) request_kwargs = { "modelId": self.model, - "inferenceConfig": inference_config, + "inferenceConfig": inference_config or self.inference_config, "messages": [message.to_dict() for message in messages], } @@ -216,21 +219,21 @@ def run( try: if streaming_callback and ModelCapability.CONVERSE_STREAM in self.model_capabilities: - response = self.client.converse_stream(**request_kwargs) - response_stream = response.get("stream") + converse_response = self.client.converse_stream(**request_kwargs) + response_stream = converse_response.get("stream") message, metadata = get_stream_message(stream=response_stream, streaming_callback=streaming_callback) else: - response = self.client.converse(**request_kwargs) - output = response.get("output") - # TODO: Delete - print(output) + converse_response = self.client.converse(**request_kwargs) + output = converse_response.get("output") if output is None: - raise KeyError("Response does not contain 'output'") + response_output_missing_error = "Response does not contain 'output'" + raise KeyError(response_output_missing_error) message = output.get("message") if message is None: - raise KeyError("Response 'output' does not contain 'message'") + response_output_missing_message_error = "Response 'output' does not contain 'message'" + raise KeyError(response_output_missing_message_error) message = ConverseMessage.from_dict(message) - metadata = response + metadata = converse_response return { "message": message, diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 22be5c29c..a5ed548b8 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -33,12 +33,14 @@ class ToolConfig: toolChoice: Optional[ToolChoice] = None def __post_init__(self): + msg = "Only one of 'auto', 'any', or 'tool' can be set in toolChoice" if self.toolChoice and sum(bool(v) for v in vars(self.toolChoice).values()) != 1: - raise ValueError("Only one of 'auto', 'any', or 'tool' can be set in toolChoice") + raise ValueError(msg) if self.toolChoice and self.toolChoice.tool: if "name" not in self.toolChoice.tool: - raise ValueError("'name' is required when 'tool' is specified in toolChoice") + msg = "'name' is required when 'tool' is specified in toolChoice" + raise ValueError(msg) @staticmethod def from_functions(functions: List[Callable]) -> "ToolConfig": @@ -76,7 +78,18 @@ def from_dict(cls, config: Dict) -> "ToolConfig": return cls(tools=tools, toolChoice=tool_choice) def to_dict(self) -> Dict[str, Any]: - result = {"tools": [{"toolSpec": asdict(tool.toolSpec)} for tool in self.tools]} + result = { + "tools": [ + { + "toolSpec": { + "name": tool.toolSpec.name, + "description": tool.toolSpec.description, + "inputSchema": tool.toolSpec.inputSchema, + } + } + for tool in self.tools + ] + } if self.toolChoice: tool_choice: Dict[str, Dict[str, Any]] = {} if self.toolChoice.auto: @@ -151,17 +164,19 @@ 
class ContentBlock: content: List[Union[DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock]] def __post_init__(self): + err_msg = "Content must be a list" if not isinstance(self.content, list): - raise ValueError("Content must be a list") + raise ValueError(err_msg) for item in self.content: if not isinstance( item, (DocumentBlock, GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, ToolUseBlock) ): - raise ValueError( + msg = ( f"Invalid content type: {type(item)}. Each item must be one of DocumentBlock, " "GuardrailConverseContentBlock, ImageBlock, str, ToolResultBlock, or ToolUseBlock" ) + raise ValueError(msg) @staticmethod def from_assistant(content: Sequence[Union[str, ToolUseBlock]]) -> "ContentBlock": @@ -183,7 +198,8 @@ def to_dict(self): elif isinstance(item, ToolUseBlock): res.append({"toolUse": asdict(item)}) else: - raise ValueError(f"Unsupported content type: {type(item)}") + msg = f"Unsupported content type: {type(item)}" + raise ValueError(msg) return res @@ -231,7 +247,8 @@ def from_dict(data: Dict[str, Any]) -> "ConverseMessage": elif "guardContent" in item: content_blocks.append(GuardrailConverseContentBlock(**item["guardContent"])) else: - raise ValueError(f"Unknown content type in message: {item}") + unknown_type = f"Unknown content type in message: {item}" + raise ValueError(unknown_type) return ConverseMessage(role, ContentBlock(content=content_blocks)) @@ -312,9 +329,11 @@ def get_stream_message( - metadata: Indicates metadata about the message. The function processes each event in the stream and returns a ConverseMessage and the associated metadata. - The ConverseMessage will contain the content of the message, and the metadata will contain the stop reason and any other metadata from the stream. + The ConverseMessage will contain the content of the message, + and the metadata will contain the stop reason and any other metadata from the stream. - The function will also call the streaming_callback function with a ConverseStreamingChunk for each event in the stream. + The function will also call the streaming_callback function + with a ConverseStreamingChunk for each event in the stream. The ConverseStreamingChunk will contain the content and metadata from the event. :param stream: The stream of messages to process. 
diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 3b106a817..b4a4be18e 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -1,24 +1,22 @@ import json -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch + import pytest from botocore.exceptions import ClientError +from haystack_integrations.common.amazon_bedrock.errors import ( + AmazonBedrockInferenceError, +) from haystack_integrations.components.generators.amazon_bedrock import ( - ConverseMessage, - AmazonBedrockConverseGenerator, MODEL_CAPABILITIES, + AmazonBedrockConverseGenerator, + ConverseMessage, ToolConfig, ) -from haystack_integrations.common.amazon_bedrock.errors import ( - AmazonBedrockConfigurationError, - AmazonBedrockInferenceError, -) from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ( ConverseRole, ImageBlock, ImageSource, - StreamEvent, - ToolResultBlock, ToolUseBlock, ) @@ -55,11 +53,31 @@ def test_to_dict(mock_boto3_session): expected_dict = { "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator.AmazonBedrockConverseGenerator", "init_parameters": { - "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, - "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, - "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, - "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, - "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, "model": "anthropic.claude-3-5-sonnet-20240620-v1:0", "inference_config": { "temperature": 0.1, @@ -189,82 +207,74 @@ def test_run_with_different_message_types(mock_session): call_args = mock_client.converse.call_args[1] assert len(call_args["messages"]) == 2 assert call_args["messages"][0]["content"] == [{"text": "What's the weather like?"}] - print(f"Actual content of second message: {call_args["messages'][1]['content"]}") - # Depending on the actual behavior, you might need to adjust the following assertion: - assert call_args["messages"][1]["content"] == [] # or whatever the actual behavior is - -from botocore.stub import Stubber - -from unittest.mock import Mock, patch -import pytest -from botocore.exceptions import ClientError +from unittest.mock import patch def test_streaming(): generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") mocked_events = [ - {'messageStart': {'role': 'assistant'}}, - {'contentBlockDelta': {'delta': {'text': 'To'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' answer'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' your questions'}, 'contentBlockIndex': 0}}, - 
{'contentBlockDelta': {'delta': {'text': ','}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': " I'll"}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' need to'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' use'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' two'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' different functions'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ':'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' one'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' to check'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' the weather'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' in Paris and another'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' to get the current'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' time in New York'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': '.'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' Let'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' me'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' fetch'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' that'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' information for'}, 'contentBlockIndex': 0}}, - {'contentBlockDelta': {'delta': {'text': ' you.'}, 'contentBlockIndex': 0}}, - {'contentBlockStop': {'contentBlockIndex': 0}}, + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "To"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " answer"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " your questions"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": ","}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " I'll"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " need to"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " use"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " two"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " different functions"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": ":"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " one"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " to check"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " the weather"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " in Paris and another"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " to get the current"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " time in New York"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": "."}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " Let"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " me"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " fetch"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " that"}, "contentBlockIndex": 0}}, + {"contentBlockDelta": {"delta": {"text": " information for"}, "contentBlockIndex": 0}}, + 
{"contentBlockDelta": {"delta": {"text": " you."}, "contentBlockIndex": 0}}, + {"contentBlockStop": {"contentBlockIndex": 0}}, { - 'contentBlockStart': { - 'start': {'toolUse': {'toolUseId': 'tooluse_5Uu9EPSjQxiSsmc5Ex5MJg', 'name': 'get_current_weather'}}, - 'contentBlockIndex': 1, + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "tooluse_5Uu9EPSjQxiSsmc5Ex5MJg", "name": "get_current_weather"}}, + "contentBlockIndex": 1, } }, - {'contentBlockDelta': {'delta': {'toolUse': {'input': ''}}, 'contentBlockIndex': 1}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': '{"loc'}}, 'contentBlockIndex': 1}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': 'ation":'}}, 'contentBlockIndex': 1}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': ' "Paris"'}}, 'contentBlockIndex': 1}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': ', "u'}}, 'contentBlockIndex': 1}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': 'nit": "ce'}}, 'contentBlockIndex': 1}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': 'lsius'}}, 'contentBlockIndex': 1}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': '"}'}}, 'contentBlockIndex': 1}}, - {'contentBlockStop': {'contentBlockIndex': 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"loc'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ation":'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ' "Paris"'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "u'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'nit": "ce'}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "lsius"}}, "contentBlockIndex": 1}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '"}'}}, "contentBlockIndex": 1}}, + {"contentBlockStop": {"contentBlockIndex": 1}}, { - 'contentBlockStart': { - 'start': {'toolUse': {'toolUseId': 'tooluse_cbK-e15KTFqZHtwpBJ0kzg', 'name': 'get_current_time'}}, - 'contentBlockIndex': 2, + "contentBlockStart": { + "start": {"toolUse": {"toolUseId": "tooluse_cbK-e15KTFqZHtwpBJ0kzg", "name": "get_current_time"}}, + "contentBlockIndex": 2, } }, - {'contentBlockDelta': {'delta': {'toolUse': {'input': ''}}, 'contentBlockIndex': 2}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': '{"timezon'}}, 'contentBlockIndex': 2}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': 'e"'}}, 'contentBlockIndex': 2}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': ': "A'}}, 'contentBlockIndex': 2}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': 'meric'}}, 'contentBlockIndex': 2}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': 'a/New'}}, 'contentBlockIndex': 2}}, - {'contentBlockDelta': {'delta': {'toolUse': {'input': '_York"}'}}, 'contentBlockIndex': 2}}, - {'contentBlockStop': {'contentBlockIndex': 2}}, - {'messageStop': {'stopReason': 'tool_use'}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"timezon'}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": 'e"'}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "A'}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": "meric"}}, "contentBlockIndex": 2}}, + 
{"contentBlockDelta": {"delta": {"toolUse": {"input": "a/New"}}, "contentBlockIndex": 2}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '_York"}'}}, "contentBlockIndex": 2}}, + {"contentBlockStop": {"contentBlockIndex": 2}}, + {"messageStop": {"stopReason": "tool_use"}}, { - 'metadata': { - 'usage': {'inputTokens': 446, 'outputTokens': 118, 'totalTokens': 564}, - 'metrics': {'latencyMs': 3930}, + "metadata": { + "usage": {"inputTokens": 446, "outputTokens": 118, "totalTokens": 564}, + "metrics": {"latencyMs": 3930}, } }, ] @@ -272,7 +282,7 @@ def test_streaming(): mock_stream = Mock() mock_stream.__iter__ = Mock(return_value=iter(mocked_events)) - generator.client.converse_stream = Mock(return_value={'stream': mock_stream}) + generator.client.converse_stream = Mock(return_value={"stream": mock_stream}) chunks = [] result = generator.run( @@ -281,14 +291,17 @@ def test_streaming(): ) assert len(chunks) == len(mocked_events) - assert result["message"].content.content[0] == "To answer your questions, I'll need to use two different functions: one to check the weather in Paris and another to get the current time in New York. Let me fetch that information for you." + assert ( + result["message"].content.content[0] + == "To answer your questions, I'll need to use two different functions: one to check the weather in Paris and another to get the current time in New York. Let me fetch that information for you." + ) assert len(result["message"].content.content) == 3 - assert result["stop_reason"] == 'tool_use' + assert result["stop_reason"] == "tool_use" assert result["message"].role == ConverseRole.ASSISTANT - assert result["usage"] == {'inputTokens': 446, 'outputTokens': 118, 'totalTokens': 564} - assert result["metrics"] == {'latencyMs': 3930} + assert result["usage"] == {"inputTokens": 446, "outputTokens": 118, "totalTokens": 564} + assert result["metrics"] == {"latencyMs": 3930} assert result["message"].content.content[1].name == "get_current_weather" assert result["message"].content.content[2].name == "get_current_time" @@ -315,8 +328,31 @@ def test_tool_usage(): generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", tool_config=tool_config) mock_response = { - "output": {'message': {'role': 'assistant', 'content': [{'text': "Certainly! I'd be happy to help you with the weather in Paris and the current time in New York. To get this information, I'll need to use two different tools. Let me fetch that data for you."}, {'toolUse': {'toolUseId': 'tooluse_-Tp78_OeSq-1DSsP0B__TA', 'name': 'get_current_weather', 'input': {'location': 'Paris', 'unit': 'celsius'}}}, {'toolUse': {'toolUseId': 'tooluse_gdYvqeiGTme7toWoV4sSKw', 'name': 'get_current_time', 'input': {'timezone': 'America/New_York'},},},],},}, - "stop_reason": "tool_use", + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "Certainly! I'd be happy to help you with the weather in Paris and the current time in New York. To get this information, I'll need to use two different tools. Let me fetch that data for you." 
+ }, + { + "toolUse": { + "toolUseId": "tooluse_-Tp78_OeSq-1DSsP0B__TA", + "name": "get_current_weather", + "input": {"location": "Paris", "unit": "celsius"}, + } + }, + { + "toolUse": { + "toolUseId": "tooluse_gdYvqeiGTme7toWoV4sSKw", + "name": "get_current_time", + "input": {"timezone": "America/New_York"}, + }, + }, + ], + }, + }, + "stopReason": "tool_use", } generator.client.converse = Mock(return_value=mock_response) @@ -328,9 +364,14 @@ def test_tool_usage(): assert isinstance(result["message"].content.content[2], ToolUseBlock) assert result["stop_reason"] == "tool_use" assert result["message"].role == ConverseRole.ASSISTANT - assert result["message"].content.content[0] == "Certainly! I'd be happy to help you with the weather in Paris and the current time in New York. To get this information, I'll need to use two different tools. Let me fetch that data for you." + assert ( + result["message"].content.content[0] + == "Certainly! I'd be happy to help you with the weather in Paris and the current time in New York. To get this information, I'll need to use two different tools. Let me fetch that data for you." + ) assert result["message"].content.content[1].name == "get_current_weather" assert result["message"].content.content[2].name == "get_current_time" assert json.dumps(result["message"].content.content[1].input) == """{"location": "Paris", "unit": "celsius"}""" assert json.dumps(result["message"].content.content[2].input) == """{"timezone": "America/New_York"}""" - + assert result["message"].content.content[1].input["location"] == "Paris" + assert result["message"].content.content[1].input["unit"] == "celsius" + assert result["message"].content.content[2].input["timezone"] == "America/New_York" From abf79dd88cd324eef3cad1bc18e586c9006fd986 Mon Sep 17 00:00:00 2001 From: FloRul Date: Fri, 6 Sep 2024 23:01:37 -0400 Subject: [PATCH 22/35] clean up --- .../examples/converse_generator_example.py | 76 ++++++++----------- .../amazon_bedrock/converse/capabilities.py | 1 + .../converse/converse_generator.py | 53 +++++++++++-- 3 files changed, 79 insertions(+), 51 deletions(-) diff --git a/integrations/amazon_bedrock/examples/converse_generator_example.py b/integrations/amazon_bedrock/examples/converse_generator_example.py index 034829610..816db4988 100644 --- a/integrations/amazon_bedrock/examples/converse_generator_example.py +++ b/integrations/amazon_bedrock/examples/converse_generator_example.py @@ -5,55 +5,45 @@ def get_current_weather(location: str, unit: str = "celsius") -> str: - """Get the current weather in a given location""" - # This is a mock function, replace with actual API call return f"The weather in {location} is 22 degrees {unit}." def get_current_time(timezone: str) -> str: - """Get the current time in a given timezone""" - # This is a mock function, replace with actual time lookup return f"The current time in {timezone} is 14:30." 
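# NOTE: ToolConfig.from_functions (see utils.py in this PR) fills each toolSpec's
# description from func.__doc__ and derives the input schema from
# inspect.signature(func), so stripping the docstrings above leaves the generated
# tools without descriptions. A rough sketch of what from_functions now produces
# for get_current_weather (shape inferred from the from_functions implementation
# shown later in this series):
#
#     {"toolSpec": {
#         "name": "get_current_weather",
#         "description": None,  # __doc__ was removed in this commit
#         "inputSchema": {"json": {
#             "type": "object",
#             "properties": {"location": {"type": "string"},
#                            "unit": {"type": "string"}}}}}}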
-def main(): - generator = AmazonBedrockConverseGenerator( - model="anthropic.claude-3-5-sonnet-20240620-v1:0", - # streaming_callback=print, - ) - - # Create ToolConfig from functions - tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) - - # Convert ToolConfig to dict for use in the run method - tool_config_dict = tool_config.to_dict() - - print("Tool Config:") - print(tool_config_dict) - - pipeline = Pipeline() - pipeline.add_component("generator", generator) - - print("\nRunning pipeline with tools:") - result = pipeline.run( - data={ - "generator": { - "inference_config": { - "temperature": 0.1, - "maxTokens": 256, - "topP": 0.1, - "stopSequences": ["\\n"], - }, - "messages": [ - ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"]), - ], - "tool_config": tool_config_dict, - }, - }, - ) - print("\nPipeline Result:") - print(result) +generator = AmazonBedrockConverseGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + # streaming_callback=print, +) + +# Create ToolConfig from functions +tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) + +tool_config_dict = tool_config.to_dict() +print("Tool Config:") +print(tool_config_dict) -if __name__ == "__main__": - main() +pipeline = Pipeline() +pipeline.add_component("generator", generator) + +print("\nRunning pipeline with tools:") +result = pipeline.run( + data={ + "generator": { + "inference_config": { + "temperature": 0.1, + "maxTokens": 256, + "topP": 0.1, + "stopSequences": ["\\n"], + }, + "messages": [ + ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"]), + ], + "tool_config": tool_config_dict, + }, + }, +) +print("\nPipeline Result:") +print(result) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py index d5dc251a2..3f5f40d56 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py @@ -13,6 +13,7 @@ class ModelCapability(Enum): # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html + MODEL_CAPABILITIES = { "ai21.jamba-instruct-.*": { ModelCapability.CONVERSE, diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index e5b9ae19f..7fc7809ed 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -34,18 +34,54 @@ class AmazonBedrockConverseGenerator: ### Usage example ```python - from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator - from haystack.dataclasses import ChatMessage - from haystack.components.generators.utils import print_streaming_chunk + from haystack import Pipeline - messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"), - 
ChatMessage.from_user("What's Natural Language Processing?")] + from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockConverseGenerator + from haystack_integrations.components.generators.amazon_bedrock.converse.utils import ConverseMessage, ToolConfig - client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", - streaming_callback=print_streaming_chunk) - client.run(messages, generation_kwargs={"max_tokens": 512}) + def get_current_weather(location: str, unit: str = "celsius") -> str: + return f"The weather in {location} is 22 degrees {unit}." + + def get_current_time(timezone: str) -> str: + return f"The current time in {timezone} is 14:30." + + + generator = AmazonBedrockConverseGenerator( + model="anthropic.claude-3-5-sonnet-20240620-v1:0", + # streaming_callback=print, + ) + + # Create ToolConfig from functions + tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) + + # Convert ToolConfig to dict for use in the run method + tool_config_dict = tool_config.to_dict() + + print("Tool Config:") + print(tool_config_dict) + + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + result = pipeline.run( + data={ + "generator": { + "inference_config": { + "temperature": 0.1, + "maxTokens": 256, + "topP": 0.1, + "stopSequences": ["\\n"], + }, + "messages": [ + ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"]), + ], + "tool_config": tool_config_dict, + }, + }, + ) + print(result) ``` AmazonBedrockChatGenerator uses AWS for authentication. You can use the AWS CLI to authenticate through your IAM. @@ -101,6 +137,7 @@ def __init__( switches the streaming mode on. :param inference_config: A dictionary containing the inference configuration. The default value is None. :param tool_config: A dictionary containing the tool configuration. The default value is None. + :param system_prompt: A list of dictionaries containing the system prompt. The default value is None. """ if not model: msg = "'model' cannot be None or empty string" From 9a459af3076f77adb1ab5d08dd56cea974e6457f Mon Sep 17 00:00:00 2001 From: FloRul Date: Fri, 6 Sep 2024 23:11:33 -0400 Subject: [PATCH 23/35] prepare for PR --- integrations/amazon_bedrock/CHANGELOG.md | 6 ++++++ .../amazon_bedrock/tests/test_converse_generator.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 417c661fe..33cdca69c 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -2,6 +2,12 @@ ## [unreleased] +### 🚀 Features + +- New generator for the bedrock converse API (#977) + +## [unreleased] + ### 🐛 Bug Fixes - *(Bedrock)* Allow tools kwargs for AWS Bedrock Claude model (#976) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index b4a4be18e..622174795 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -33,7 +33,7 @@ def get_current_time(timezone: str) -> str: return f"The current time in {timezone} is 14:30." 
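# The fixture changes below touch test_to_dict/test_from_dict, which pin the
# component's serialized form. A condensed round-trip sketch of what those two
# tests effectively guarantee (field values here are illustrative, not copied
# verbatim from the tests):
#
#     generator = AmazonBedrockConverseGenerator(
#         model="anthropic.claude-3-5-sonnet-20240620-v1:0",
#         inference_config={"temperature": 0.1, "maxTokens": 256},
#     )
#     restored = AmazonBedrockConverseGenerator.from_dict(generator.to_dict())
#     assert restored.model == generator.model
#     assert restored.inference_config == generator.inference_config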
-def test_to_dict(mock_boto3_session): +def test_to_dict(): """ Test that the to_dict method returns the correct dictionary without aws credentials """ @@ -94,7 +94,7 @@ def test_to_dict(mock_boto3_session): assert generator.to_dict() == expected_dict -def test_from_dict(mock_boto3_session): +def test_from_dict(): """ Test that the from_dict method returns the correct object """ From f5c59f8ad1e91d3c0438057927c76a4f264aed7d Mon Sep 17 00:00:00 2001 From: FloRul Date: Sat, 7 Sep 2024 00:26:01 -0400 Subject: [PATCH 24/35] Update test_converse_generator.py --- integrations/amazon_bedrock/tests/test_converse_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 622174795..158af41b4 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -1,5 +1,5 @@ -import json from unittest.mock import Mock, patch +import json import pytest from botocore.exceptions import ClientError From cd63bcf4e45f7968126559f4cf8dbfc9c0d458ee Mon Sep 17 00:00:00 2001 From: FloRul Date: Sat, 7 Sep 2024 20:43:21 -0400 Subject: [PATCH 25/35] fixing encoding --- .../generators/amazon_bedrock/converse/capabilities.py | 2 +- .../components/generators/amazon_bedrock/converse/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py index 3f5f40d56..98dcd53b3 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/capabilities.py @@ -1,4 +1,4 @@ -from enum import Enum, auto +from enum import Enum, auto class ModelCapability(Enum): diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index a5ed548b8..67b826796 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -1,4 +1,4 @@ -import inspect +import inspect import json import logging from dataclasses import asdict, dataclass, field From 0af2b8959db436e76f9280dacfae4cec0c740d70 Mon Sep 17 00:00:00 2001 From: FloRul Date: Sat, 7 Sep 2024 21:19:37 -0400 Subject: [PATCH 26/35] fixing lint and tests --- .../examples/converse_generator_example.py | 5 -- .../amazon_bedrock/converse/utils.py | 64 ++++++++++--------- .../tests/test_converse_generator.py | 23 +++---- 3 files changed, 47 insertions(+), 45 deletions(-) diff --git a/integrations/amazon_bedrock/examples/converse_generator_example.py b/integrations/amazon_bedrock/examples/converse_generator_example.py index 816db4988..1f0dd6d73 100644 --- a/integrations/amazon_bedrock/examples/converse_generator_example.py +++ b/integrations/amazon_bedrock/examples/converse_generator_example.py @@ -22,13 +22,10 @@ def get_current_time(timezone: str) -> str: tool_config_dict = tool_config.to_dict() -print("Tool Config:") 
-print(tool_config_dict) pipeline = Pipeline() pipeline.add_component("generator", generator) -print("\nRunning pipeline with tools:") result = pipeline.run( data={ "generator": { @@ -45,5 +42,3 @@ def get_current_time(timezone: str) -> str: }, }, ) -print("\nPipeline Result:") -print(result) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 67b826796..23e07b05d 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -12,12 +12,12 @@ class ToolSpec: name: str description: Optional[str] = None - inputSchema: Dict[str, Dict] = field(default_factory=dict) + input_schema: Dict[str, Dict] = field(default_factory=dict) @dataclass class Tool: - toolSpec: ToolSpec + tool_spec: ToolSpec @dataclass @@ -30,16 +30,16 @@ class ToolChoice: @dataclass class ToolConfig: tools: List[Tool] - toolChoice: Optional[ToolChoice] = None + tool_choice: Optional[ToolChoice] = None def __post_init__(self): - msg = "Only one of 'auto', 'any', or 'tool' can be set in toolChoice" - if self.toolChoice and sum(bool(v) for v in vars(self.toolChoice).values()) != 1: + msg = "Only one of 'auto', 'any', or 'tool' can be set in tool_choice" + if self.tool_choice and sum(bool(v) for v in vars(self.tool_choice).values()) != 1: raise ValueError(msg) - if self.toolChoice and self.toolChoice.tool: - if "name" not in self.toolChoice.tool: - msg = "'name' is required when 'tool' is specified in toolChoice" + if self.tool_choice and self.tool_choice.tool: + if "name" not in self.tool_choice.tool: + msg = "'name' is required when 'tool' is specified in tool_choice" raise ValueError(msg) @staticmethod @@ -49,7 +49,7 @@ def from_functions(functions: List[Callable]) -> "ToolConfig": tool_spec = ToolSpec( name=func.__name__, description=func.__doc__, - inputSchema={ + input_schema={ "json": { "type": "object", "properties": {param: {"type": "string"} for param in inspect.signature(func).parameters}, @@ -57,17 +57,17 @@ def from_functions(functions: List[Callable]) -> "ToolConfig": } }, ) - tools.append(Tool(toolSpec=tool_spec)) + tools.append(Tool(tool_spec=tool_spec)) return ToolConfig(tools=tools) @classmethod def from_dict(cls, config: Dict) -> "ToolConfig": - tools = [Tool(ToolSpec(**tool["toolSpec"])) for tool in config.get("tools", [])] + tools = [Tool(ToolSpec(**tool["tool_spec"])) for tool in config.get("tools", [])] tool_choice = None - if "toolChoice" in config: - tc = config["toolChoice"] + if "tool_choice" in config: + tc = config["tool_choice"] if "auto" in tc: tool_choice = ToolChoice(auto=tc["auto"]) elif "any" in tc: @@ -75,30 +75,30 @@ def from_dict(cls, config: Dict) -> "ToolConfig": elif "tool" in tc: tool_choice = ToolChoice(tool={"name": tc["tool"]["name"]}) - return cls(tools=tools, toolChoice=tool_choice) + return cls(tools=tools, tool_choice=tool_choice) def to_dict(self) -> Dict[str, Any]: result = { "tools": [ { "toolSpec": { - "name": tool.toolSpec.name, - "description": tool.toolSpec.description, - "inputSchema": tool.toolSpec.inputSchema, + "name": tool.tool_spec.name, + "description": tool.tool_spec.description, + "inputSchema": tool.tool_spec.input_schema, } } for tool in self.tools ] } - if self.toolChoice: + if self.tool_choice: 
tool_choice: Dict[str, Dict[str, Any]] = {} - if self.toolChoice.auto: - tool_choice["auto"] = self.toolChoice.auto - elif self.toolChoice.any: - tool_choice["any"] = self.toolChoice.any - elif self.toolChoice.tool: - tool_choice["tool"] = self.toolChoice.tool - result["toolChoice"] = [tool_choice] + if self.tool_choice.auto: + tool_choice["auto"] = self.tool_choice.auto + elif self.tool_choice.any: + tool_choice["any"] = self.tool_choice.any + elif self.tool_choice.tool: + tool_choice["tool"] = self.tool_choice.tool + result["tool_choice"] = [tool_choice] return result @@ -142,14 +142,14 @@ class ToolResultContentBlock: @dataclass class ToolResultBlock: - toolUseId: str + tool_use_id: str content: List[ToolResultContentBlock] status: Optional[str] = None @dataclass class ToolUseBlock: - toolUseId: str + tool_use_id: str name: str input: Dict[str, Any] @@ -241,7 +241,13 @@ def from_dict(data: Dict[str, Any]) -> "ConverseMessage": elif "document" in item: content_blocks.append(DocumentBlock(**item["document"])) elif "toolUse" in item: - content_blocks.append(ToolUseBlock(**item["toolUse"])) + content_blocks.append( + ToolUseBlock( + tool_use_id=item["toolUse"]["toolUseId"], + name=item["toolUse"]["name"], + input=item["toolUse"]["input"], + ) + ) elif "toolResult" in item: content_blocks.append(ToolResultBlock(**item["toolResult"])) elif "guardContent" in item: @@ -285,7 +291,7 @@ def _handle_content_block_start(event: StreamEvent, current_index: int) -> Tuple start_of_tool_use = event.data.get("start") if start_of_tool_use: return new_index, ToolUseBlock( - toolUseId=start_of_tool_use["toolUse"]["toolUseId"], + tool_use_id=start_of_tool_use["toolUse"]["toolUseId"], name=start_of_tool_use["toolUse"]["name"], input={}, ) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 158af41b4..14af7ebb7 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -1,5 +1,5 @@ -from unittest.mock import Mock, patch import json +from unittest.mock import Mock, patch import pytest from botocore.exceptions import ClientError @@ -51,7 +51,8 @@ def test_to_dict(): ) expected_dict = { - "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator.AmazonBedrockConverseGenerator", + "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator." + "AmazonBedrockConverseGenerator", "init_parameters": { "aws_access_key_id": { "type": "env_var", @@ -102,7 +103,8 @@ def test_from_dict(): generator = AmazonBedrockConverseGenerator.from_dict( { - "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator.AmazonBedrockConverseGenerator", + "type": "haystack_integrations.components.generators.amazon_bedrock.converse.converse_generator." 
+ "AmazonBedrockConverseGenerator", "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -209,9 +211,6 @@ def test_run_with_different_message_types(mock_session): assert call_args["messages"][0]["content"] == [{"text": "What's the weather like?"}] -from unittest.mock import patch - - def test_streaming(): generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") @@ -293,7 +292,8 @@ def test_streaming(): assert len(chunks) == len(mocked_events) assert ( result["message"].content.content[0] - == "To answer your questions, I'll need to use two different functions: one to check the weather in Paris and another to get the current time in New York. Let me fetch that information for you." + == "To answer your questions, I'll need to use two different functions: one to check the weather " + "in Paris and another to get the current time in New York. Let me fetch that information for you." ) assert len(result["message"].content.content) == 3 assert result["stop_reason"] == "tool_use" @@ -333,7 +333,8 @@ def test_tool_usage(): "role": "assistant", "content": [ { - "text": "Certainly! I'd be happy to help you with the weather in Paris and the current time in New York. To get this information, I'll need to use two different tools. Let me fetch that data for you." + "text": "I'll get the weather in Paris and the current time in New York for you. " + "To do this, I'll need to use two different tools. Let me fetch that data." }, { "toolUse": { @@ -364,9 +365,9 @@ def test_tool_usage(): assert isinstance(result["message"].content.content[2], ToolUseBlock) assert result["stop_reason"] == "tool_use" assert result["message"].role == ConverseRole.ASSISTANT - assert ( - result["message"].content.content[0] - == "Certainly! I'd be happy to help you with the weather in Paris and the current time in New York. To get this information, I'll need to use two different tools. Let me fetch that data for you." + assert result["message"].content.content[0] == ( + "I'll get the weather in Paris and the current time in New York for you. " + "To do this, I'll need to use two different tools. Let me fetch that data." ) assert result["message"].content.content[1].name == "get_current_weather" assert result["message"].content.content[2].name == "get_current_time" From 933c19286aaeba7fbe65baffc08518701e5d92ce Mon Sep 17 00:00:00 2001 From: FloRul Date: Sat, 7 Sep 2024 21:50:03 -0400 Subject: [PATCH 27/35] fixing tests and serialization --- .../examples/converse_generator_example.py | 5 ++++- .../generators/amazon_bedrock/converse/utils.py | 14 +++++++++++++- .../tests/test_converse_generator.py | 2 +- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/integrations/amazon_bedrock/examples/converse_generator_example.py b/integrations/amazon_bedrock/examples/converse_generator_example.py index 1f0dd6d73..e7d588454 100644 --- a/integrations/amazon_bedrock/examples/converse_generator_example.py +++ b/integrations/amazon_bedrock/examples/converse_generator_example.py @@ -5,16 +5,18 @@ def get_current_weather(location: str, unit: str = "celsius") -> str: + """Get the current weather in a given location""" return f"The weather in {location} is 22 degrees {unit}." def get_current_time(timezone: str) -> str: + """Get the current time in a given timezone""" return f"The current time in {timezone} is 14:30." 
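# With the docstrings restored, from_functions can populate the tool
# descriptions again. ToolConfig.to_dict() (see utils.py in this PR) emits the
# camelCase shape that run() forwards to Converse as `toolConfig`; for the two
# helpers above it looks roughly like this:
#
#     {"tools": [
#         {"toolSpec": {
#             "name": "get_current_weather",
#             "description": "Get the current weather in a given location",
#             "inputSchema": {"json": {
#                 "type": "object",
#                 "properties": {"location": {"type": "string"},
#                                "unit": {"type": "string"}}}}}},
#         {"toolSpec": {
#             "name": "get_current_time",
#             "description": "Get the current time in a given timezone",
#             "inputSchema": {"json": {
#                 "type": "object",
#                 "properties": {"timezone": {"type": "string"}}}}}},
#     ]}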
generator = AmazonBedrockConverseGenerator( model="anthropic.claude-3-5-sonnet-20240620-v1:0", - # streaming_callback=print, + streaming_callback=print, ) # Create ToolConfig from functions @@ -42,3 +44,4 @@ def get_current_time(timezone: str) -> str: }, }, ) +print(result) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 23e07b05d..5799878e7 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -63,7 +63,19 @@ def from_functions(functions: List[Callable]) -> "ToolConfig": @classmethod def from_dict(cls, config: Dict) -> "ToolConfig": - tools = [Tool(ToolSpec(**tool["tool_spec"])) for tool in config.get("tools", [])] + tools = [ + Tool( + ToolSpec( + input_schema=tool["toolSpec"]["inputSchema"], + name=tool["toolSpec"]["name"], + description=tool["toolSpec"]["description"], + ) + ) + for tool in config.get( + "tools", + [], + ) + ] tool_choice = None if "tool_choice" in config: diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 14af7ebb7..f0bb6c32f 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -174,7 +174,7 @@ def test_get_model_capabilities(): assert generator.model_capabilities == MODEL_CAPABILITIES["anthropic.claude-3.*"] generator = AmazonBedrockConverseGenerator(model="ai21.j2-ultra-instruct-v1") - assert generator.model_capabilities == MODEL_CAPABILITIES["ai21.j2-.*-instruct"] + assert generator.model_capabilities == MODEL_CAPABILITIES["ai21.j2-.*"] with pytest.raises(ValueError, match="Unsupported model"): AmazonBedrockConverseGenerator(model="unsupported.model-v1") From b3904fc70002635a5ff8035127387d24e1ab071b Mon Sep 17 00:00:00 2001 From: FloRul Date: Sat, 7 Sep 2024 21:50:50 -0400 Subject: [PATCH 28/35] linter --- .../amazon_bedrock/examples/converse_generator_example.py | 1 - 1 file changed, 1 deletion(-) diff --git a/integrations/amazon_bedrock/examples/converse_generator_example.py b/integrations/amazon_bedrock/examples/converse_generator_example.py index e7d588454..82969b6c2 100644 --- a/integrations/amazon_bedrock/examples/converse_generator_example.py +++ b/integrations/amazon_bedrock/examples/converse_generator_example.py @@ -44,4 +44,3 @@ def get_current_time(timezone: str) -> str: }, }, ) -print(result) From 5a456b891d7b688dab86ff395ef00db9f653e9a4 Mon Sep 17 00:00:00 2001 From: FloRul Date: Sat, 7 Sep 2024 22:30:58 -0400 Subject: [PATCH 29/35] fix lint and tests --- .../converse/converse_generator.py | 23 +++++++++++++------ .../amazon_bedrock/converse/utils.py | 10 +------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 7fc7809ed..876c314fa 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ 
b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -17,7 +17,14 @@ MODEL_CAPABILITIES, ModelCapability, ) -from .utils import ContentBlock, ConverseMessage, ConverseStreamingChunk, ImageBlock, ToolConfig, get_stream_message +from .utils import ( + ContentBlock, + ConverseMessage, + ImageBlock, + StreamEvent, + ToolConfig, + get_stream_message, +) logger = logging.getLogger(__name__) @@ -110,7 +117,7 @@ def __init__( aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False), # noqa: B008 inference_config: Optional[Dict[str, Any]] = None, tool_config: Optional[ToolConfig] = None, - streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = None, + streaming_callback: Optional[Callable[[StreamEvent], None]] = None, system_prompt: Optional[List[Dict[str, Any]]] = None, ): """ @@ -191,7 +198,7 @@ def _get_model_capabilities(self, model: str) -> Set[ModelCapability]: def run( self, messages: List[ConverseMessage], - streaming_callback: Optional[Callable[[ConverseStreamingChunk], None]] = None, + streaming_callback: Optional[Callable[[StreamEvent], None]] = None, inference_config: Optional[Dict[str, Any]] = None, tool_config: Optional[ToolConfig] = None, system_prompt: Optional[List[Dict[str, Any]]] = None, @@ -248,7 +255,7 @@ def run( } if tool_config: - request_kwargs["toolConfig"] = tool_config + request_kwargs["toolConfig"] = tool_config.to_dict() if system_prompt: request_kwargs["system"] = { "text": system_prompt, @@ -258,18 +265,20 @@ def run( if streaming_callback and ModelCapability.CONVERSE_STREAM in self.model_capabilities: converse_response = self.client.converse_stream(**request_kwargs) response_stream = converse_response.get("stream") - message, metadata = get_stream_message(stream=response_stream, streaming_callback=streaming_callback) + message, metadata = get_stream_message( + stream=response_stream, + streaming_callback=streaming_callback, + ) else: converse_response = self.client.converse(**request_kwargs) output = converse_response.get("output") if output is None: response_output_missing_error = "Response does not contain 'output'" raise KeyError(response_output_missing_error) - message = output.get("message") + message = ConverseMessage.from_dict(output.get("message")) if message is None: response_output_missing_message_error = "Response 'output' does not contain 'message'" raise KeyError(response_output_missing_message_error) - message = ConverseMessage.from_dict(message) metadata = converse_response return { diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 5799878e7..927d815b7 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -277,14 +277,6 @@ def to_dict(self): } -@dataclass -class ConverseStreamingChunk: - content: Union[str, ToolUseBlock] - metadata: Dict[str, Any] - index: int = 0 - type: str = "" - - @dataclass class StreamEvent: type: str @@ -324,7 +316,7 @@ def _handle_content_block_delta( return current_block, current_tool_use_input_str + delta["toolUse"].get("input", "") else: return ToolUseBlock( - toolUseId=delta["toolUse"]["toolUseId"], + 
tool_use_id=delta["toolUse"]["toolUseId"], name=delta["toolUse"]["name"], input={}, ), delta["toolUse"].get("input", "") From ea0ff1e65db7b11bded2d55afc104d5b9fed5fbb Mon Sep 17 00:00:00 2001 From: FloRul Date: Sat, 7 Sep 2024 23:09:40 -0400 Subject: [PATCH 30/35] is it AWS_REGION over AWS_DEFAULT_REGION as specified in the pipeline ? --- .../amazon_bedrock/converse/converse_generator.py | 14 +++++++------- .../tests/test_converse_generator.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index 876c314fa..f014ad19c 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -327,13 +327,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": :returns: Deserialized component. """ - init_params = data.get("init_parameters", {}) - serialized_callback_handler = init_params.get("streaming_callback") - if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) - tool_config = data.get("init_parameters", {}).get("tool_config") - if tool_config: - data["init_parameters"]["tool_config"] = ToolConfig.from_dict(tool_config) deserialize_secrets_inplace( data["init_parameters"], [ @@ -344,4 +337,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": "aws_profile_name", ], ) + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + tool_config = data.get("init_parameters", {}).get("tool_config") + if tool_config: + data["init_parameters"]["tool_config"] = ToolConfig.from_dict(tool_config) return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index f0bb6c32f..0ca5dfd4a 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -71,7 +71,7 @@ def test_to_dict(): }, "aws_region_name": { "type": "env_var", - "env_vars": ["AWS_DEFAULT_REGION"], + "env_vars": ["AWS_REGION"], "strict": False, }, "aws_profile_name": { From 6b92ff3f985527b131324be72651be2e61e37b35 Mon Sep 17 00:00:00 2001 From: FloRul Date: Sat, 7 Sep 2024 23:17:01 -0400 Subject: [PATCH 31/35] Revert "is it AWS_REGION over AWS_DEFAULT_REGION as specified in the pipeline ?" This reverts commit ea0ff1e65db7b11bded2d55afc104d5b9fed5fbb. 
--- .../amazon_bedrock/converse/converse_generator.py | 14 +++++++------- .../tests/test_converse_generator.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py index f014ad19c..876c314fa 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/converse_generator.py @@ -327,6 +327,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": :returns: Deserialized component. """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + tool_config = data.get("init_parameters", {}).get("tool_config") + if tool_config: + data["init_parameters"]["tool_config"] = ToolConfig.from_dict(tool_config) deserialize_secrets_inplace( data["init_parameters"], [ @@ -337,11 +344,4 @@ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockConverseGenerator": "aws_profile_name", ], ) - init_params = data.get("init_parameters", {}) - serialized_callback_handler = init_params.get("streaming_callback") - if serialized_callback_handler: - data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) - tool_config = data.get("init_parameters", {}).get("tool_config") - if tool_config: - data["init_parameters"]["tool_config"] = ToolConfig.from_dict(tool_config) return default_from_dict(cls, data) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 0ca5dfd4a..f0bb6c32f 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -71,7 +71,7 @@ def test_to_dict(): }, "aws_region_name": { "type": "env_var", - "env_vars": ["AWS_REGION"], + "env_vars": ["AWS_DEFAULT_REGION"], "strict": False, }, "aws_profile_name": { From 43d442491401bff2d5581df60f5f35f75bbd3dac Mon Sep 17 00:00:00 2001 From: FloRul Date: Mon, 9 Sep 2024 10:35:17 -0400 Subject: [PATCH 32/35] Update test_converse_generator.py --- integrations/amazon_bedrock/tests/test_converse_generator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index f0bb6c32f..42a03737b 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -1,4 +1,5 @@ import json +import os from unittest.mock import Mock, patch import pytest @@ -33,7 +34,7 @@ def get_current_time(timezone: str) -> str: return f"The current time in {timezone} is 14:30." 
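# The mock_boto3_session fixture re-added below is defined outside this patch
# series (presumably in the integration's tests/conftest.py). A minimal sketch
# of such a fixture; the patch target and return values here are assumptions,
# not taken from this PR:
#
#     import pytest
#     from unittest.mock import MagicMock, patch
#
#     @pytest.fixture
#     def mock_boto3_session():
#         with patch("boto3.Session") as mock_session:
#             mock_session.return_value.client.return_value = MagicMock()
#             yield mock_session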
-def test_to_dict(): +def test_to_dict(mock_boto3_session): """ Test that the to_dict method returns the correct dictionary without aws credentials """ @@ -95,7 +96,7 @@ def test_to_dict(): assert generator.to_dict() == expected_dict -def test_from_dict(): +def test_from_dict(mock_boto3_session): """ Test that the from_dict method returns the correct object """ From b5b200418db46f14eb20e99768d22eda7aa4fb54 Mon Sep 17 00:00:00 2001 From: FloRul Date: Mon, 9 Sep 2024 10:37:47 -0400 Subject: [PATCH 33/35] remove unused os import --- integrations/amazon_bedrock/tests/test_converse_generator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 42a03737b..7dd7ba381 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -1,5 +1,4 @@ import json -import os from unittest.mock import Mock, patch import pytest From 1f56a4aa8bc78fe48e1d19c6ba763a5463754eb0 Mon Sep 17 00:00:00 2001 From: FloRul Date: Mon, 9 Sep 2024 11:37:00 -0400 Subject: [PATCH 34/35] fix test fixtures --- .../components/generators/amazon_bedrock/converse/utils.py | 2 +- .../amazon_bedrock/tests/test_converse_generator.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py index 927d815b7..9ebf03806 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/converse/utils.py @@ -242,7 +242,7 @@ def from_user( @staticmethod def from_dict(data: Dict[str, Any]) -> "ConverseMessage": - role = ConverseRole(data["role"]) + role = ConverseRole.ASSISTANT if data["role"] == "assistant" else ConverseRole.USER content_blocks = [] for item in data["content"]: diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 7dd7ba381..0760c88af 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -211,7 +211,7 @@ def test_run_with_different_message_types(mock_session): assert call_args["messages"][0]["content"] == [{"text": "What's the weather like?"}] -def test_streaming(): +def test_streaming(mock_boto3_session): generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") mocked_events = [ @@ -310,7 +310,7 @@ def test_streaming(): assert json.dumps(result["message"].content.content[2].input) == """{"timezone": "America/New_York"}""" -def test_client_error_handling(): +def test_client_error_handling(mock_boto3_session): generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") generator.client.converse = Mock( side_effect=ClientError( @@ -323,7 +323,7 @@ def test_client_error_handling(): generator.run([ConverseMessage.from_user(["Hi"])]) -def test_tool_usage(): +def test_tool_usage(mock_boto3_session): tool_config = ToolConfig.from_functions([get_current_weather, get_current_time]) generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", tool_config=tool_config) From 
5f686c8ce4761e1dbb7cbcdbc402015ebde116f1 Mon Sep 17 00:00:00 2001 From: FloRul Date: Thu, 12 Sep 2024 22:03:58 -0400 Subject: [PATCH 35/35] Add mock_boto3_session fixture to test_get_model_capabilities --- integrations/amazon_bedrock/tests/test_converse_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/amazon_bedrock/tests/test_converse_generator.py b/integrations/amazon_bedrock/tests/test_converse_generator.py index 0760c88af..af8e621d8 100644 --- a/integrations/amazon_bedrock/tests/test_converse_generator.py +++ b/integrations/amazon_bedrock/tests/test_converse_generator.py @@ -169,7 +169,7 @@ def test_constructor_with_empty_model(): AmazonBedrockConverseGenerator(model="") -def test_get_model_capabilities(): +def test_get_model_capabilities(mock_boto3_session): generator = AmazonBedrockConverseGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0") assert generator.model_capabilities == MODEL_CAPABILITIES["anthropic.claude-3.*"]
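None of the tests or examples in this series exercise the turn that follows a `tool_use` stop, where the caller executes the requested tools and returns their outputs to the model. A minimal sketch of that loop, assuming the helpers and generator from converse_generator_example.py are in scope; the `toolResult` payload follows the Bedrock Converse message conventions, and passing raw toolResult dicts through ConverseMessage.from_user is an assumption about this integration's API rather than something this PR demonstrates:

    from haystack_integrations.components.generators.amazon_bedrock.converse.utils import (
        ConverseMessage,
        ToolUseBlock,
    )

    tools = {"get_current_weather": get_current_weather, "get_current_time": get_current_time}
    question = ConverseMessage.from_user(["What's the weather like in Paris and what time is it in New York?"])

    result = generator.run([question])
    if result["stop_reason"] == "tool_use":
        tool_results = []
        for block in result["message"].content.content:
            if isinstance(block, ToolUseBlock):
                # Execute the tool the model asked for, with the arguments it supplied.
                output = tools[block.name](**block.input)
                tool_results.append(
                    {
                        "toolResult": {
                            "toolUseId": block.tool_use_id,
                            "content": [{"json": {"result": output}}],
                            "status": "success",
                        }
                    }
                )
        # Send the results back as the next user turn so the model can compose
        # its final answer from them.
        followup = generator.run([question, result["message"], ConverseMessage.from_user(tool_results)])
        print(followup["message"])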