Skip to content

Commit

Permalink
okay, both files now
Browse files Browse the repository at this point in the history
  • Loading branch information
kwindla committed Aug 15, 2024
1 parent 6e0dd4a commit 94deec0
Showing 1 changed file with 58 additions and 11 deletions.
69 changes: 58 additions & 11 deletions src/pipecat/services/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,21 @@ def __init__(
api_key: str,
model: str = "claude-3-5-sonnet-20240620",
max_tokens: int = 4096,
enable_prompt_caching_beta: bool = False,
**kwargs):
super().__init__(**kwargs)
self._client = AsyncAnthropic(api_key=api_key)
self._model = model
self._max_tokens = max_tokens
self._enable_prompt_caching_beta = enable_prompt_caching_beta

def can_generate_metrics(self) -> bool:
return True

@property
def enable_prompt_caching_beta(self) -> bool:
return self._enable_prompt_caching_beta

@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
user = AnthropicUserContextAggregator(context)
Expand All @@ -98,6 +104,17 @@ def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggr
)

async def _process_context(self, context: OpenAILLMContext):
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
# completion_tokens. We also estimate the completion tokens from output text
# and use that estimate if we are interrupted, because we almost certainly won't
# get a complete usage report if the task we're running in is cancelled.
prompt_tokens = 0
completion_tokens = 0
completion_tokens_estimate = 0
use_completion_tokens_estimate = False
cache_creation_input_tokens = 0
cache_read_input_tokens = 0

try:
await self.push_frame(LLMFullResponseStartFrame())
await self.start_processing_metrics()
Expand All @@ -106,13 +123,19 @@ async def _process_context(self, context: OpenAILLMContext):
f"Generating chat: {context.system} | {context.get_messages_for_logging()}")

messages = context.messages
if self._enable_prompt_caching_beta:
messages = context.get_messages_with_cache_control_markers()

api_call = self._client.messages.create
if self._enable_prompt_caching_beta:
api_call = self._client.beta.prompt_caching.messages.create

await self.start_ttfb_metrics()

response = await self._client.messages.create(
response = await api_call(
tools=context.tools or [],
system=context.system or [],
messages=messages,
tools=context.tools or [],
model=self._model,
max_tokens=self._max_tokens,
stream=True)
Expand All @@ -123,15 +146,6 @@ async def _process_context(self, context: OpenAILLMContext):
tool_use_block = None
json_accumulator = ''

# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
# completion_tokens. We also estimate the completion tokens from output text
# and use that estimate if we are interrupted, because we almost certainly won't
# get a complete usage report if the task we're running in is cancelled.
prompt_tokens = 0
completion_tokens = 0
completion_tokens_estimate = 0
use_completion_tokens_estimate = False

async for event in response:
# logger.debug(f"Anthropic LLM event: {event}")

Expand Down Expand Up @@ -170,6 +184,15 @@ async def _process_context(self, context: OpenAILLMContext):
event.message.usage, "input_tokens") else 0
completion_tokens += event.message.usage.output_tokens if hasattr(
event.message.usage, "output_tokens") else 0
if hasattr(event.message.usage, "cache_creation_input_tokens"):
cache_creation_input_tokens += event.message.usage.cache_creation_input_tokens
logger.debug(f"Cache creation input tokens: {cache_creation_input_tokens}")
if hasattr(event.message.usage, "cache_read_input_tokens"):
cache_read_input_tokens += event.message.usage.cache_read_input_tokens
logger.debug(f"Cache read input tokens: {cache_read_input_tokens}")
total_input_tokens = prompt_tokens + cache_creation_input_tokens + cache_read_input_tokens
if total_input_tokens >= 1024:
context.turns_above_cache_threshold += 1

except CancelledError:
# If we're interrupted, we won't get a complete usage report. So set our flag to use the
Expand Down Expand Up @@ -241,6 +264,12 @@ def __init__(
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
self._user_image_request_context = {}

# For beta prompt caching. This is a counter that tracks the number of turns
# we've seen above the cache threshold. We reset this when we reset the
# messages list. We only care about this number being 0, 1, or 2. But
# it's easiest just to treat it as a counter.
self.turns_above_cache_threshold = 0

self.system = system

@classmethod
Expand Down Expand Up @@ -270,6 +299,7 @@ def from_image_frame(cls, frame: VisionImageRawFrame) -> "AnthropicLLMContext":
return context

def set_messages(self, messages: List):
self.turns_above_cache_threshold = 0
self._messages[:] = messages
self._restructure_from_openai_messages()

Expand Down Expand Up @@ -313,6 +343,23 @@ def add_message(self, message):
except Exception as e:
logger.error(f"Error adding message: {e}")

def get_messages_with_cache_control_markers(self) -> List[dict]:
try:
messages = copy.deepcopy(self.messages)
if self.turns_above_cache_threshold >= 1 and messages[-1]["role"] == "user":
if isinstance(messages[-1]["content"], str):
messages[-1]["content"] = [{"type": "text", "text": messages[-1]["content"]}]
messages[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
if (self.turns_above_cache_threshold >= 2 and
len(messages) > 2 and messages[-3]["role"] == "user"):
if isinstance(messages[-3]["content"], str):
messages[-3]["content"] = [{"type": "text", "text": messages[-3]["content"]}]
messages[-3]["content"][-1]["cache_control"] = {"type": "ephemeral"}
return messages
except Exception as e:
logger.error(f"Error adding cache control marker: {e}")
return self.messages

def _restructure_from_openai_messages(self):
# See if we should pull the system message out of our context.messages list. (For
# compatibility with Open AI messages format.)
Expand Down

0 comments on commit 94deec0

Please sign in to comment.