From d50a7cdf5cabd7f32cadfbc54dd01ca779bac394 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 11:52:54 +0100 Subject: [PATCH 1/2] make Anthropic compatible with new chatmessage; fix prompt caching tests --- .../generators/anthropic/chat/chat_generator.py | 13 ++++++++++--- integrations/anthropic/tests/test_chat_generator.py | 6 ++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 43b50495c..56a740146 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -1,4 +1,3 @@ -import dataclasses import json from typing import Any, Callable, ClassVar, Dict, List, Optional, Union @@ -275,8 +274,16 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict """ anthropic_formatted_messages = [] for m in messages: - message_dict = dataclasses.asdict(m) - formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + message_dict = m.to_dict() + formatted_message = {} + + # legacy format + if "role" in message_dict and "content" in message_dict: + formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + # new format + elif "_role" in message_dict and "_content" in message_dict: + formatted_message = {"role": m.role.value, "content": m.text} + if m.is_from(ChatRole.SYSTEM): # system messages are treated differently and MUST be in the format expected by the Anthropic API # remove role and content from the message dict, add type and text diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 9a111fc9d..69b3265aa 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -421,6 +421,8 @@ def test_prompt_caching(self, cache_enabled): assert len(result["replies"]) == 1 token_usage = result["replies"][0].meta.get("usage") + print(token_usage) + if cache_enabled: # either we created cache or we read it (depends on how you execute this integration test) assert ( @@ -428,5 +430,5 @@ def test_prompt_caching(self, cache_enabled): or token_usage.get("cache_read_input_tokens") > 1024 ) else: - assert "cache_creation_input_tokens" not in token_usage - assert "cache_read_input_tokens" not in token_usage + assert token_usage["cache_creation_input_tokens"] == 0 + assert token_usage["cache_read_input_tokens"] == 0 From 8e3d6f187ff350d63932f89ec9e17fb869f37814 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 18 Dec 2024 14:47:25 +0100 Subject: [PATCH 2/2] rm print --- integrations/anthropic/tests/test_chat_generator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 69b3265aa..36622ecd9 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -421,8 +421,6 @@ def test_prompt_caching(self, cache_enabled): assert len(result["replies"]) == 1 token_usage = result["replies"][0].meta.get("usage") - print(token_usage) - if cache_enabled: # either we created cache or we read it (depends on how you execute this integration test) assert (