From 36f16c1979f5df72d4e7e6fe595e91148e00daf8 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 20 Sep 2024 09:49:33 +0200 Subject: [PATCH] feat: Add Anthropic prompt caching support, add example (#1006) * Add prompt caching, add example * Print prompt caching data in example * Lint * Anthropic allows multiple system messages, simplify * PR feedback * Update prompt_caching.py example to use ChatPromptBuilder 2.5 fixes * Small fixes * Add unit tests * Improve UX for prompt caching example * Add unit test for _convert_to_anthropic_format * More integration tests * Update test to turn on/off prompt cache --- .../anthropic/example/prompt_caching.py | 102 +++++++++++ integrations/anthropic/pyproject.toml | 1 + .../anthropic/chat/chat_generator.py | 54 ++++-- .../anthropic/tests/test_chat_generator.py | 168 ++++++++++++++++++ 4 files changed, 314 insertions(+), 11 deletions(-) create mode 100644 integrations/anthropic/example/prompt_caching.py diff --git a/integrations/anthropic/example/prompt_caching.py b/integrations/anthropic/example/prompt_caching.py new file mode 100644 index 000000000..d8cc0f0e8 --- /dev/null +++ b/integrations/anthropic/example/prompt_caching.py @@ -0,0 +1,102 @@ +# To run this example, you will need to set a `ANTHROPIC_API_KEY` environment variable. + +import time + +from haystack import Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.components.converters import HTMLToDocument +from haystack.components.fetchers import LinkContentFetcher +from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.utils import Secret + +from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator + +# Advanced: We can also cache the HTTP GET requests for the HTML content to avoid repeating requests +# that fetched the same content. +# This type of caching requires requests_cache library to be installed +# Uncomment the following two lines to caching the HTTP requests + +# import requests_cache +# requests_cache.install_cache("anthropic_demo") + +ENABLE_PROMPT_CACHING = True # Toggle this to enable or disable prompt caching + + +def create_streaming_callback(): + first_token_time = None + + def stream_callback(chunk: StreamingChunk) -> None: + nonlocal first_token_time + if first_token_time is None: + first_token_time = time.time() + print(chunk.content, flush=True, end="") + + return stream_callback, lambda: first_token_time + + +# Until prompt caching graduates from beta, we need to set the anthropic-beta header +generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if ENABLE_PROMPT_CACHING else {} + +claude_llm = AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), generation_kwargs=generation_kwargs +) + +pipe = Pipeline() +pipe.add_component("fetcher", LinkContentFetcher()) +pipe.add_component("converter", HTMLToDocument()) +pipe.add_component("prompt_builder", ChatPromptBuilder(variables=["documents"])) +pipe.add_component("llm", claude_llm) +pipe.connect("fetcher", "converter") +pipe.connect("converter", "prompt_builder") +pipe.connect("prompt_builder.prompt", "llm.messages") + +system_message = ChatMessage.from_system( + "Claude is an AI assistant that answers questions based on the given documents.\n" + "Here are the documents:\n" + "{% for d in documents %} \n" + " {{d.content}} \n" + "{% endfor %}" +) + +if ENABLE_PROMPT_CACHING: + system_message.meta["cache_control"] = {"type": "ephemeral"} + +questions = [ + "What's this paper about?", + "What's the main contribution of this paper?", + "How can findings from this paper be applied to real-world problems?", +] + +for question in questions: + print(f"Question: {question}") + start_time = time.time() + streaming_callback, get_first_token_time = create_streaming_callback() + # reset LLM streaming callback to initialize new timers in streaming mode + claude_llm.streaming_callback = streaming_callback + + result = pipe.run( + data={ + "fetcher": {"urls": ["https://ar5iv.labs.arxiv.org/html/2310.04406"]}, + "prompt_builder": {"template": [system_message, ChatMessage.from_user(f"Answer the question: {question}")]}, + } + ) + + end_time = time.time() + total_time = end_time - start_time + time_to_first_token = get_first_token_time() - start_time + print("\n" + "-" * 100) + print(f"Total generation time: {total_time:.2f} seconds") + print(f"Time to first token: {time_to_first_token:.2f} seconds") + # first time we create a prompt cache usage key 'cache_creation_input_tokens' will have a value of the number of + # tokens used to create the prompt cache + # on first subsequent cache hit we'll see a usage key 'cache_read_input_tokens' having a value of the number of + # tokens read from the cache + token_stats = result["llm"]["replies"][0].meta.get("usage") + if token_stats.get("cache_creation_input_tokens", 0) > 0: + print("Cache created! ", end="") + elif token_stats.get("cache_read_input_tokens", 0) > 0: + print("Cache hit! ", end="") + else: + print("Cache not used, something is wrong with the prompt caching setup. ", end="") + print(f"Cache usage details: {token_stats}") + print("\n" + "=" * 100) diff --git a/integrations/anthropic/pyproject.toml b/integrations/anthropic/pyproject.toml index 3f8c9812b..e1d3fa867 100644 --- a/integrations/anthropic/pyproject.toml +++ b/integrations/anthropic/pyproject.toml @@ -106,6 +106,7 @@ select = [ "YTT", ] ignore = [ + "T201", # print statements # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible passwords diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 9954f08c5..43b50495c 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -72,6 +72,7 @@ class AnthropicChatGenerator: "temperature", "top_p", "top_k", + "extra_headers", ] def __init__( @@ -101,6 +102,7 @@ def __init__( - `temperature`: The temperature to use for sampling. - `top_p`: The top_p value to use for nucleus sampling. - `top_k`: The top_k value to use for top-k sampling. + - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features). :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a "chain of thought" messages before returning the actual function names and parameters in a message. If `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool @@ -177,20 +179,35 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, f"Model parameters {disallowed_params} are not allowed and will be ignored. " f"Allowed parameters are {self.ALLOWED_PARAMS}." ) + system_messages: List[ChatMessage] = [msg for msg in messages if msg.is_from(ChatRole.SYSTEM)] + non_system_messages: List[ChatMessage] = [msg for msg in messages if not msg.is_from(ChatRole.SYSTEM)] + system_messages_formatted: List[Dict[str, Any]] = ( + self._convert_to_anthropic_format(system_messages) if system_messages else [] + ) + messages_formatted: List[Dict[str, Any]] = ( + self._convert_to_anthropic_format(non_system_messages) if non_system_messages else [] + ) - # adapt ChatMessage(s) to the format expected by the Anthropic API - anthropic_formatted_messages = self._convert_to_anthropic_format(messages) - - # system message provided by the user overrides the system message from the self.generation_kwargs - system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None - if system: - anthropic_formatted_messages = anthropic_formatted_messages[1:] + extra_headers = filtered_generation_kwargs.get("extra_headers", {}) + prompt_caching_on = "anthropic-beta" in extra_headers and "prompt-caching" in extra_headers["anthropic-beta"] + has_cached_messages = any("cache_control" in m for m in system_messages_formatted) or any( + "cache_control" in m for m in messages_formatted + ) + if has_cached_messages and not prompt_caching_on: + # this avoids Anthropic errors when prompt caching is not enabled + # but user requested individual messages to be cached + logger.warn( + "Prompt caching is not enabled but you requested individual messages to be cached. " + "Messages will be sent to the API without prompt caching." + ) + system_messages_formatted = list(map(self._remove_cache_control, system_messages_formatted)) + messages_formatted = list(map(self._remove_cache_control, messages_formatted)) response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create( max_tokens=filtered_generation_kwargs.pop("max_tokens", 512), - system=system if system else filtered_generation_kwargs.pop("system", ""), + system=system_messages_formatted or filtered_generation_kwargs.pop("system", ""), model=self.model, - messages=anthropic_formatted_messages, + messages=messages_formatted, stream=self.streaming_callback is not None, **filtered_generation_kwargs, ) @@ -259,8 +276,15 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict anthropic_formatted_messages = [] for m in messages: message_dict = dataclasses.asdict(m) - filtered_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} - anthropic_formatted_messages.append(filtered_message) + formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + if m.is_from(ChatRole.SYSTEM): + # system messages are treated differently and MUST be in the format expected by the Anthropic API + # remove role and content from the message dict, add type and text + formatted_message.pop("role") + formatted_message["type"] = "text" + formatted_message["text"] = formatted_message.pop("content") + formatted_message.update(m.meta or {}) + anthropic_formatted_messages.append(formatted_message) return anthropic_formatted_messages def _connect_chunks( @@ -291,3 +315,11 @@ def _build_chunk(self, delta: TextDelta) -> StreamingChunk: :returns: The StreamingChunk. """ return StreamingChunk(content=delta.text) + + def _remove_cache_control(self, message: Dict[str, Any]) -> Dict[str, Any]: + """ + Removes the cache_control key from the message. + :param message: The message to remove the cache_control key from. + :returns: The message with the cache_control key removed. + """ + return {k: v for k, v in message.items() if k != "cache_control"} diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 3ffa24c94..155cf7950 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -262,3 +262,171 @@ def test_tools_use(self): fc_response = json.loads(first_reply.content) assert "name" in fc_response, "First reply does not contain name of the tool" assert "input" in fc_response, "First reply does not contain input of the tool" + + def test_prompt_caching_enabled(self, monkeypatch): + """ + Test that the generation_kwargs extra_headers are correctly passed to the Anthropic API when prompt + caching is enabled + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator( + generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} + ) + assert component.generation_kwargs.get("extra_headers", {}).get("anthropic-beta") == "prompt-caching-2024-07-31" + + def test_prompt_caching_cache_control_without_extra_headers(self, monkeypatch, mock_chat_completion, caplog): + """ + Test that the cache_control is removed from the messages when prompt caching is not enabled via extra_headers + This is to avoid Anthropic errors when prompt caching is not enabled + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator() + + messages = [ChatMessage.from_system("System message"), ChatMessage.from_user("User message")] + + # Add cache_control to messages + for msg in messages: + msg.meta["cache_control"] = {"type": "ephemeral"} + + # Invoke run with messages + component.run(messages) + + # Check caplog for the warning message that should have been logged + assert any("Prompt caching" in record.message for record in caplog.records) + + # Check that the Anthropic API was called without cache_control in messages so that it does not raise an error + _, kwargs = mock_chat_completion.call_args + for msg in kwargs["messages"]: + assert "cache_control" not in msg + + @pytest.mark.parametrize("enable_caching", [True, False]) + def test_run_with_prompt_caching(self, monkeypatch, mock_chat_completion, enable_caching): + """ + Test that the generation_kwargs extra_headers are correctly passed to the Anthropic API in both cases of + prompt caching being enabled or not + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + + generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if enable_caching else {} + component = AnthropicChatGenerator(generation_kwargs=generation_kwargs) + + messages = [ChatMessage.from_system("System message"), ChatMessage.from_user("User message")] + + component.run(messages) + + # Check that the Anthropic API was called with the correct headers + _, kwargs = mock_chat_completion.call_args + headers = kwargs.get("extra_headers", {}) + if enable_caching: + assert "anthropic-beta" in headers + else: + assert "anthropic-beta" not in headers + + def test_to_dict_with_prompt_caching(self, monkeypatch): + """ + Test that the generation_kwargs extra_headers are correctly serialized to a dictionary + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator( + generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} + ) + data = component.to_dict() + assert ( + data["init_parameters"]["generation_kwargs"]["extra_headers"]["anthropic-beta"] + == "prompt-caching-2024-07-31" + ) + + def test_from_dict_with_prompt_caching(self, monkeypatch): + """ + Test that the generation_kwargs extra_headers are correctly deserialized from a dictionary + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + data = { + "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, + "model": "claude-3-5-sonnet-20240620", + "generation_kwargs": {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}}, + }, + } + component = AnthropicChatGenerator.from_dict(data) + assert component.generation_kwargs["extra_headers"]["anthropic-beta"] == "prompt-caching-2024-07-31" + + def test_convert_messages_to_anthropic_format(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + generator = AnthropicChatGenerator() + + # Test scenario 1: Regular user and assistant messages + messages = [ + ChatMessage.from_user("Hello"), + ChatMessage.from_assistant("Hi there!"), + ] + result = generator._convert_to_anthropic_format(messages) + assert result == [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + # Test scenario 2: System message + messages = [ChatMessage.from_system("You are a helpful assistant.")] + result = generator._convert_to_anthropic_format(messages) + assert result == [{"type": "text", "text": "You are a helpful assistant."}] + + # Test scenario 3: Mixed message types + messages = [ + ChatMessage.from_system("Be concise."), + ChatMessage.from_user("What's AI?"), + ChatMessage.from_assistant("Artificial Intelligence."), + ] + result = generator._convert_to_anthropic_format(messages) + assert result == [ + {"type": "text", "text": "Be concise."}, + {"role": "user", "content": "What's AI?"}, + {"role": "assistant", "content": "Artificial Intelligence."}, + ] + + # Test scenario 4: metadata + messages = [ + ChatMessage.from_user("What's AI?"), + ChatMessage.from_assistant("Artificial Intelligence.", meta={"confidence": 0.9}), + ] + result = generator._convert_to_anthropic_format(messages) + assert result == [ + {"role": "user", "content": "What's AI?"}, + {"role": "assistant", "content": "Artificial Intelligence.", "confidence": 0.9}, + ] + + # Test scenario 5: Empty message list + assert generator._convert_to_anthropic_format([]) == [] + + @pytest.mark.integration + @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY", None), reason="ANTHROPIC_API_KEY not set") + @pytest.mark.parametrize("cache_enabled", [True, False]) + def test_prompt_caching(self, cache_enabled): + generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if cache_enabled else {} + + claude_llm = AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), generation_kwargs=generation_kwargs + ) + + # see https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations + system_message = ChatMessage.from_system("This is the cached, here we make it at least 1024 tokens long." * 70) + if cache_enabled: + system_message.meta["cache_control"] = {"type": "ephemeral"} + + messages = [system_message, ChatMessage.from_user("What's in cached content?")] + result = claude_llm.run(messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + token_usage = result["replies"][0].meta.get("usage") + + if cache_enabled: + # either we created cache or we read it (depends on how you execute this integration test) + assert ( + token_usage.get("cache_creation_input_tokens") > 1024 + or token_usage.get("cache_read_input_tokens") > 1024 + ) + else: + assert "cache_creation_input_tokens" not in token_usage + assert "cache_read_input_tokens" not in token_usage