Skip to content

Commit

Permalink
feat: Add Anthropic prompt caching support, add example (#1006)
Browse files Browse the repository at this point in the history
* Add prompt caching, add example

* Print prompt caching data in example

* Lint

* Anthropic allows multiple system messages, simplify

* PR feedback

* Update prompt_caching.py example to use ChatPromptBuilder 2.5 fixes

* Small fixes

* Add unit tests

* Improve UX for prompt caching example

* Add unit test for _convert_to_anthropic_format

* More integration tests

* Update test to turn on/off prompt cache
  • Loading branch information
vblagoje authored Sep 20, 2024
1 parent c0750ec commit 36f16c1
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 11 deletions.
102 changes: 102 additions & 0 deletions integrations/anthropic/example/prompt_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# To run this example, you will need to set a `ANTHROPIC_API_KEY` environment variable.

import time

from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.components.converters import HTMLToDocument
from haystack.components.fetchers import LinkContentFetcher
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils import Secret

from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator

# Advanced: We can also cache the HTTP GET requests for the HTML content to avoid repeating requests
# that fetched the same content.
# This type of caching requires requests_cache library to be installed
# Uncomment the following two lines to cache the HTTP requests

# import requests_cache
# requests_cache.install_cache("anthropic_demo")

# Master switch for this example: when True, the script sends the `anthropic-beta` prompt-caching
# header and tags the system message with `cache_control`; when False, neither is done.
ENABLE_PROMPT_CACHING = True # Toggle this to enable or disable prompt caching


def create_streaming_callback():
    """
    Build a streaming callback together with an accessor for the first-token timestamp.

    :returns: A tuple `(stream_callback, get_first_token_time)`. `stream_callback` prints every
        streamed chunk as it arrives and records the wall-clock time of the first one.
        `get_first_token_time` returns that timestamp, or `None` if no chunk has been seen yet.
    """
    # Mutable holder shared by both closures; the timestamp is written once, on the first chunk
    timing = {"first_token": None}

    def stream_callback(chunk: StreamingChunk) -> None:
        if timing["first_token"] is None:
            timing["first_token"] = time.time()
        print(chunk.content, flush=True, end="")

    def get_first_token_time():
        return timing["first_token"]

    return stream_callback, get_first_token_time


# Prompt caching is still a beta feature, so it has to be switched on explicitly
# through the `anthropic-beta` request header.
if ENABLE_PROMPT_CACHING:
    generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}}
else:
    generation_kwargs = {}

claude_llm = AnthropicChatGenerator(
    generation_kwargs=generation_kwargs,
    api_key=Secret.from_env_var("ANTHROPIC_API_KEY"),
)

# Pipeline layout: fetch the URL -> convert HTML to documents -> render the chat prompt -> Claude
pipe = Pipeline()

for component_name, component in (
    ("fetcher", LinkContentFetcher()),
    ("converter", HTMLToDocument()),
    ("prompt_builder", ChatPromptBuilder(variables=["documents"])),
    ("llm", claude_llm),
):
    pipe.add_component(component_name, component)

for sender, receiver in (
    ("fetcher", "converter"),
    ("converter", "prompt_builder"),
    ("prompt_builder.prompt", "llm.messages"),
):
    pipe.connect(sender, receiver)

# Jinja template for the system prompt; the fetched documents are rendered into it by the prompt builder
system_prompt_template = (
    "Claude is an AI assistant that answers questions based on the given documents.\n"
    "Here are the documents:\n"
    "{% for d in documents %} \n"
    " {{d.content}} \n"
    "{% endfor %}"
)
system_message = ChatMessage.from_system(system_prompt_template)

# Mark the (large, repeated) system prompt as cacheable only when caching is switched on
if ENABLE_PROMPT_CACHING:
    system_message.meta.update({"cache_control": {"type": "ephemeral"}})

# Questions asked one after another against the same document, so that every run after the
# first one can reuse the cached system prompt.
questions = [
    "What's this paper about?",
    "What's the main contribution of this paper?",
    "How can findings from this paper be applied to real-world problems?",
]

for question in questions:
    print(f"Question: {question}")
    start_time = time.time()
    streaming_callback, get_first_token_time = create_streaming_callback()
    # reset LLM streaming callback to initialize new timers in streaming mode
    claude_llm.streaming_callback = streaming_callback

    result = pipe.run(
        data={
            "fetcher": {"urls": ["https://ar5iv.labs.arxiv.org/html/2310.04406"]},
            "prompt_builder": {"template": [system_message, ChatMessage.from_user(f"Answer the question: {question}")]},
        }
    )

    end_time = time.time()
    total_time = end_time - start_time
    first_token_time = get_first_token_time()
    print("\n" + "-" * 100)
    print(f"Total generation time: {total_time:.2f} seconds")
    # The timestamp stays None when no chunk was streamed; subtracting from it would raise TypeError
    if first_token_time is None:
        print("Time to first token: n/a (no tokens were streamed)")
    else:
        print(f"Time to first token: {first_token_time - start_time:.2f} seconds")

    # The first request that creates a prompt cache reports the number of cached tokens under the
    # usage key 'cache_creation_input_tokens'; later requests that hit the cache report the number
    # of tokens read from it under 'cache_read_input_tokens'.
    # Fall back to an empty dict so missing usage data does not crash the example.
    token_stats = result["llm"]["replies"][0].meta.get("usage") or {}
    if (token_stats.get("cache_creation_input_tokens") or 0) > 0:
        print("Cache created! ", end="")
    elif (token_stats.get("cache_read_input_tokens") or 0) > 0:
        print("Cache hit! ", end="")
    elif ENABLE_PROMPT_CACHING:
        print("Cache not used, something is wrong with the prompt caching setup. ", end="")
    else:
        # Caching was turned off on purpose, so the absence of cache usage is expected
        print("Prompt caching is disabled. ", end="")
    print(f"Cache usage details: {token_stats}")
    print("\n" + "=" * 100)
1 change: 1 addition & 0 deletions integrations/anthropic/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ select = [
"YTT",
]
ignore = [
"T201", # print statements
# Allow non-abstract empty methods in abstract base classes
"B027",
# Ignore checks for possible passwords
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class AnthropicChatGenerator:
"temperature",
"top_p",
"top_k",
"extra_headers",
]

def __init__(
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
- `temperature`: The temperature to use for sampling.
- `top_p`: The top_p value to use for nucleus sampling.
- `top_k`: The top_k value to use for top-k sampling.
- `extra_headers`: A dictionary of extra headers to be passed to the model (e.g. for beta features).
:param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a
"chain of thought" messages before returning the actual function names and parameters in a message. If
`ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool
Expand Down Expand Up @@ -177,20 +179,35 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
f"Model parameters {disallowed_params} are not allowed and will be ignored. "
f"Allowed parameters are {self.ALLOWED_PARAMS}."
)
system_messages: List[ChatMessage] = [msg for msg in messages if msg.is_from(ChatRole.SYSTEM)]
non_system_messages: List[ChatMessage] = [msg for msg in messages if not msg.is_from(ChatRole.SYSTEM)]
system_messages_formatted: List[Dict[str, Any]] = (
self._convert_to_anthropic_format(system_messages) if system_messages else []
)
messages_formatted: List[Dict[str, Any]] = (
self._convert_to_anthropic_format(non_system_messages) if non_system_messages else []
)

# adapt ChatMessage(s) to the format expected by the Anthropic API
anthropic_formatted_messages = self._convert_to_anthropic_format(messages)

# system message provided by the user overrides the system message from the self.generation_kwargs
system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None
if system:
anthropic_formatted_messages = anthropic_formatted_messages[1:]
extra_headers = filtered_generation_kwargs.get("extra_headers", {})
prompt_caching_on = "anthropic-beta" in extra_headers and "prompt-caching" in extra_headers["anthropic-beta"]
has_cached_messages = any("cache_control" in m for m in system_messages_formatted) or any(
"cache_control" in m for m in messages_formatted
)
if has_cached_messages and not prompt_caching_on:
# this avoids Anthropic errors when prompt caching is not enabled
# but user requested individual messages to be cached
logger.warn(
"Prompt caching is not enabled but you requested individual messages to be cached. "
"Messages will be sent to the API without prompt caching."
)
system_messages_formatted = list(map(self._remove_cache_control, system_messages_formatted))
messages_formatted = list(map(self._remove_cache_control, messages_formatted))

response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create(
max_tokens=filtered_generation_kwargs.pop("max_tokens", 512),
system=system if system else filtered_generation_kwargs.pop("system", ""),
system=system_messages_formatted or filtered_generation_kwargs.pop("system", ""),
model=self.model,
messages=anthropic_formatted_messages,
messages=messages_formatted,
stream=self.streaming_callback is not None,
**filtered_generation_kwargs,
)
Expand Down Expand Up @@ -259,8 +276,15 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict
anthropic_formatted_messages = []
for m in messages:
message_dict = dataclasses.asdict(m)
filtered_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v}
anthropic_formatted_messages.append(filtered_message)
formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v}
if m.is_from(ChatRole.SYSTEM):
# system messages are treated differently and MUST be in the format expected by the Anthropic API
# remove role and content from the message dict, add type and text
formatted_message.pop("role")
formatted_message["type"] = "text"
formatted_message["text"] = formatted_message.pop("content")
formatted_message.update(m.meta or {})
anthropic_formatted_messages.append(formatted_message)
return anthropic_formatted_messages

def _connect_chunks(
Expand Down Expand Up @@ -291,3 +315,11 @@ def _build_chunk(self, delta: TextDelta) -> StreamingChunk:
:returns: The StreamingChunk.
"""
return StreamingChunk(content=delta.text)

def _remove_cache_control(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""
Removes the cache_control key from the message.
:param message: The message to remove the cache_control key from.
:returns: The message with the cache_control key removed.
"""
return {k: v for k, v in message.items() if k != "cache_control"}
168 changes: 168 additions & 0 deletions integrations/anthropic/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,171 @@ def test_tools_use(self):
fc_response = json.loads(first_reply.content)
assert "name" in fc_response, "First reply does not contain name of the tool"
assert "input" in fc_response, "First reply does not contain input of the tool"

def test_prompt_caching_enabled(self, monkeypatch):
    """
    Test that the prompt caching beta header given via generation_kwargs extra_headers is kept
    on the component
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")
    beta_headers = {"anthropic-beta": "prompt-caching-2024-07-31"}
    component = AnthropicChatGenerator(generation_kwargs={"extra_headers": beta_headers})

    stored_headers = component.generation_kwargs.get("extra_headers", {})
    assert stored_headers.get("anthropic-beta") == "prompt-caching-2024-07-31"

def test_prompt_caching_cache_control_without_extra_headers(self, monkeypatch, mock_chat_completion, caplog):
    """
    Test that the cache_control is removed from the messages when prompt caching is not enabled via extra_headers
    This is to avoid Anthropic errors when prompt caching is not enabled

    Both the regular messages and the system blocks sent to the API are checked, because the
    system message is passed separately through the `system` argument.
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")
    component = AnthropicChatGenerator()

    messages = [ChatMessage.from_system("System message"), ChatMessage.from_user("User message")]

    # Add cache_control to messages
    for msg in messages:
        msg.meta["cache_control"] = {"type": "ephemeral"}

    # Invoke run with messages
    component.run(messages)

    # Check caplog for the warning message that should have been logged
    assert any("Prompt caching" in record.message for record in caplog.records)

    # Check that the Anthropic API was called without cache_control in messages so that it does not raise an error
    _, kwargs = mock_chat_completion.call_args
    for msg in kwargs["messages"]:
        assert "cache_control" not in msg

    # The system message travels in the separate `system` argument; it must be stripped as well
    for system_block in kwargs["system"]:
        assert "cache_control" not in system_block

@pytest.mark.parametrize("enable_caching", [True, False])
def test_run_with_prompt_caching(self, monkeypatch, mock_chat_completion, enable_caching):
    """
    Test that the anthropic-beta header reaches the Anthropic API call exactly when prompt caching
    is enabled through the generation_kwargs extra_headers
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")

    generation_kwargs = {}
    if enable_caching:
        generation_kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
    component = AnthropicChatGenerator(generation_kwargs=generation_kwargs)

    component.run([ChatMessage.from_system("System message"), ChatMessage.from_user("User message")])

    # Inspect the keyword arguments of the mocked API call
    call_kwargs = mock_chat_completion.call_args[1]
    sent_headers = call_kwargs.get("extra_headers", {})
    assert ("anthropic-beta" in sent_headers) == enable_caching

def test_to_dict_with_prompt_caching(self, monkeypatch):
    """
    Test that extra_headers from the generation_kwargs survive serialization with to_dict
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")
    beta_headers = {"anthropic-beta": "prompt-caching-2024-07-31"}
    component = AnthropicChatGenerator(generation_kwargs={"extra_headers": beta_headers})

    serialized = component.to_dict()
    stored_kwargs = serialized["init_parameters"]["generation_kwargs"]
    assert stored_kwargs["extra_headers"]["anthropic-beta"] == "prompt-caching-2024-07-31"

def test_from_dict_with_prompt_caching(self, monkeypatch):
    """
    Test that extra_headers in the generation_kwargs are restored when deserializing with from_dict
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")
    init_parameters = {
        "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"},
        "model": "claude-3-5-sonnet-20240620",
        "generation_kwargs": {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}},
    }
    data = {
        "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator",
        "init_parameters": init_parameters,
    }

    component = AnthropicChatGenerator.from_dict(data)

    restored_headers = component.generation_kwargs["extra_headers"]
    assert restored_headers["anthropic-beta"] == "prompt-caching-2024-07-31"

def test_convert_messages_to_anthropic_format(self, monkeypatch):
    """
    Test _convert_to_anthropic_format against a table of (input messages, expected output) pairs
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key")
    generator = AnthropicChatGenerator()

    scenarios = [
        # Scenario 1: regular user and assistant messages
        (
            [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there!")],
            [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
            ],
        ),
        # Scenario 2: a system message becomes a text block
        (
            [ChatMessage.from_system("You are a helpful assistant.")],
            [{"type": "text", "text": "You are a helpful assistant."}],
        ),
        # Scenario 3: mixed message types
        (
            [
                ChatMessage.from_system("Be concise."),
                ChatMessage.from_user("What's AI?"),
                ChatMessage.from_assistant("Artificial Intelligence."),
            ],
            [
                {"type": "text", "text": "Be concise."},
                {"role": "user", "content": "What's AI?"},
                {"role": "assistant", "content": "Artificial Intelligence."},
            ],
        ),
        # Scenario 4: message metadata is merged into the formatted message
        (
            [
                ChatMessage.from_user("What's AI?"),
                ChatMessage.from_assistant("Artificial Intelligence.", meta={"confidence": 0.9}),
            ],
            [
                {"role": "user", "content": "What's AI?"},
                {"role": "assistant", "content": "Artificial Intelligence.", "confidence": 0.9},
            ],
        ),
        # Scenario 5: empty message list
        ([], []),
    ]

    for chat_messages, expected in scenarios:
        assert generator._convert_to_anthropic_format(chat_messages) == expected

@pytest.mark.integration
@pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY", None), reason="ANTHROPIC_API_KEY not set")
@pytest.mark.parametrize("cache_enabled", [True, False])
def test_prompt_caching(self, cache_enabled):
    """
    Integration test: run a request with prompt caching turned on and off and check the cache
    related token counters reported in the reply's usage metadata
    """
    generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if cache_enabled else {}

    claude_llm = AnthropicChatGenerator(
        api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), generation_kwargs=generation_kwargs
    )

    # see https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations
    system_message = ChatMessage.from_system("This is the cached, here we make it at least 1024 tokens long." * 70)
    if cache_enabled:
        system_message.meta["cache_control"] = {"type": "ephemeral"}

    messages = [system_message, ChatMessage.from_user("What's in cached content?")]
    result = claude_llm.run(messages)

    assert "replies" in result
    assert len(result["replies"]) == 1
    token_usage = result["replies"][0].meta.get("usage")
    assert token_usage is not None, "Reply does not contain usage metadata"

    if cache_enabled:
        # either we created cache or we read it (depends on how you execute this integration test)
        # Missing or None counters count as 0 so a failure is a clean assertion error, not a TypeError
        cache_created_tokens = token_usage.get("cache_creation_input_tokens") or 0
        cache_read_tokens = token_usage.get("cache_read_input_tokens") or 0
        assert cache_created_tokens > 1024 or cache_read_tokens > 1024
    else:
        assert "cache_creation_input_tokens" not in token_usage
        assert "cache_read_input_tokens" not in token_usage

0 comments on commit 36f16c1

Please sign in to comment.