From b12461d40f907bdf5daf5862c16b4098aa3f9344 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 18 Dec 2024 15:41:16 +0100 Subject: [PATCH] fix: make Anthropic compatible with new `ChatMessage`; fix prompt caching tests (#1252) * make Anthropic compatible with new chatmessage; fix prompt caching tests * rm print --- .../generators/anthropic/chat/chat_generator.py | 13 ++++++++++--- integrations/anthropic/tests/test_chat_generator.py | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 43b50495c..56a740146 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -1,4 +1,3 @@ -import dataclasses import json from typing import Any, Callable, ClassVar, Dict, List, Optional, Union @@ -275,8 +274,16 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict """ anthropic_formatted_messages = [] for m in messages: - message_dict = dataclasses.asdict(m) - formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + message_dict = m.to_dict() + formatted_message = {} + + # legacy format + if "role" in message_dict and "content" in message_dict: + formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + # new format + elif "_role" in message_dict and "_content" in message_dict: + formatted_message = {"role": m.role.value, "content": m.text} + if m.is_from(ChatRole.SYSTEM): # system messages are treated differently and MUST be in the format expected by the Anthropic API # remove role and content from the message dict, add type and text diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 9a111fc9d..36622ecd9 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -428,5 +428,5 @@ def test_prompt_caching(self, cache_enabled): or token_usage.get("cache_read_input_tokens") > 1024 ) else: - assert "cache_creation_input_tokens" not in token_usage - assert "cache_read_input_tokens" not in token_usage + assert token_usage["cache_creation_input_tokens"] == 0 + assert token_usage["cache_read_input_tokens"] == 0