From 9272da82aee14e346760ca6ec6bf8321015e110e Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Wed, 6 Mar 2024 10:48:09 +0100
Subject: [PATCH] Migrate Claude to messaging API

---
 .../amazon_bedrock/chat/adapters.py       | 100 ++++++++----------
 .../amazon_bedrock/chat/chat_generator.py |  13 +--
 2 files changed, 53 insertions(+), 60 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 196a55743..6d8bb8e30 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -44,7 +44,7 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         :param response_body: The response body.
         :returns: The extracted responses.
         """
-        return self._extract_messages_from_response(self.response_body_message_key(), response_body)
+        return self._extract_messages_from_response(response_body)

     def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]:
         tokens: List[str] = []
@@ -53,11 +53,8 @@ def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[St
             if chunk:
                 decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
                 token = self._extract_token_from_stream(decoded_chunk)
-                # take all the rest key/value pairs from the chunk, add them to the metadata
-                stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token}
-                stream_chunk = StreamingChunk(content=token, meta=stream_metadata)
-                # callback the stream handler with StreamingChunk
-                stream_handler(stream_chunk)
+                stream_chunk = StreamingChunk(content=token)  # don't extract meta, we care about tokens only
+                stream_handler(stream_chunk)  # callback the stream handler with StreamingChunk
                 tokens.append(token)
         responses = ["".join(tokens).lstrip()]
         return responses
@@ -124,25 +121,14 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:

         :returns: A dictionary containing the resized prompt and additional information.
         """

-    def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]:
+    @abstractmethod
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
         Extracts the messages from the response body.

-        :param message_tag: The key for the message in the response body.
         :param response_body: The response body.
         :returns: The extracted ChatMessage list.
         """
-        metadata = {k: v for (k, v) in response_body.items() if k != message_tag}
-        return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)]
-
-    @abstractmethod
-    def response_body_message_key(self) -> str:
-        """
-        Returns the key for the message in the response body.
-        Subclasses should override this method to return the correct message key
-        where the response is located.
-
-        :returns: The key for the message in the response body.
- """ @abstractmethod def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: @@ -183,7 +169,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]): self.prompt_handler = DefaultPromptHandler( tokenizer="gpt2", model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512, + max_length=self.generation_kwargs.get("max_tokens") or 512, ) def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: @@ -195,8 +181,8 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ :returns: The prepared body. """ default_params = { - "max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512, - "stop_sequences": ["\n\nHuman:"], + "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", + "max_tokens": self.generation_kwargs.get("max_tokens") or 512, } # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it @@ -204,37 +190,24 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences params = self._get_params(inference_kwargs, default_params) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + body = {**self.prepare_chat_messages(messages=messages), **params} return body - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: """ Prepares the chat messages for the Anthropic Claude request. :param messages: The chat messages to prepare. :returns: The prepared chat messages as a string. """ - conversation = [] - for index, message in enumerate(messages): - if message.is_from(ChatRole.USER): - conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}") - elif message.is_from(ChatRole.ASSISTANT): - conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") - elif message.is_from(ChatRole.FUNCTION): - error_message = "Anthropic does not support function calls." - raise ValueError(error_message) - elif message.is_from(ChatRole.SYSTEM) and index == 0: - # Until we transition to the new chat message format system messages will be ignored - # see https://docs.anthropic.com/claude/reference/messages_post for more details - logger.warning( - "System messages are not fully supported by the current version of Claude and will be ignored." 
-                )
-            else:
-                invalid_role = f"Invalid role {message.role} for message {message.content}"
-                raise ValueError(invalid_role)
-
-        prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " "
-        return self._ensure_token_limit(prepared_prompt)
+        body: Dict[str, Any] = {}
+        system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
+        body["messages"] = [
+            self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT)
+        ]
+        if system:
+            body["system"] = system
+        return body

     def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
@@ -245,13 +218,19 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
         return self.prompt_handler(prompt)

-    def response_body_message_key(self) -> str:
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
-        Returns the key for the message in the response body for Anthropic Claude i.e. "completion".
+        Extracts the messages from the response body.

-        :returns: The key for the message in the response body.
+        :param response_body: The response body.
+        :returns: The extracted ChatMessage list.
         """
-        return "completion"
+        messages: List[ChatMessage] = []
+        if response_body.get("type") == "message":
+            for content in response_body["content"]:
+                if content.get("type") == "text":
+                    messages.append(ChatMessage.from_assistant(content["text"]))
+        return messages

     def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         """
@@ -260,7 +239,17 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
         :param chunk: The streaming chunk.
         :returns: The extracted token.
         """
-        return chunk.get("completion", "")
+        if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta":
+            return chunk.get("delta", {}).get("text", "")
+        return ""
+
+    def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]:
+        """
+        Converts a ChatMessage to a dictionary with the content and role fields.
+        :param m: The ChatMessage to convert.
+        :returns: The dictionary with the content and role fields.
+        """
+        return {"content": [{"type": "text", "text": m.content}], "role": m.role.value}


 class MetaLlama2ChatAdapter(BedrockModelChatAdapter):
@@ -357,13 +346,16 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]:
         """
         return self.prompt_handler(prompt)

-    def response_body_message_key(self) -> str:
+    def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]:
         """
-        Returns the key for the message in the response body for Meta Llama 2 i.e. "generation".
+        Extracts the messages from the response body.

-        :returns: The key for the message in the response body.
+        :param response_body: The response body.
+        :returns: The extracted ChatMessage list.
""" - return "generation" + message_tag = "generation" + metadata = {k: v for (k, v) in response_body.items() if k != message_tag} + return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: """ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index bea6924f6..5279dc001 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -25,20 +25,21 @@ class AmazonBedrockChatGenerator: """ `AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs. - For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockChatGenerator` with the - 'anthropic.claude-v2' model name. + For example, to use the Anthropic Claude 3 Sonnet model, simply initialize the `AmazonBedrockChatGenerator` with the + 'anthropic.claude-3-sonnet-20240229-v1:0' model name. ```python from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack.dataclasses import ChatMessage from haystack.components.generators.utils import print_streaming_chunk - messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"), + messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"), ChatMessage.from_user("What's Natural Language Processing?")] - client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk) - client.run(messages, generation_kwargs={"max_tokens_to_sample": 512}) + client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", + streaming_callback=print_streaming_chunk) + client.run(messages, generation_kwargs={"max_tokens": 512}) ``` @@ -154,7 +155,7 @@ def invoke(self, *args, **kwargs): msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." raise ValueError(msg) - body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs) + body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs}) try: if self.streaming_callback: response = self.client.invoke_model_with_response_stream(