diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py
index 7935691ce..ea1756f8c 100644
--- a/src/pipecat/services/anthropic.py
+++ b/src/pipecat/services/anthropic.py
@@ -13,6 +13,7 @@ from PIL import Image
 from asyncio import CancelledError
 import re
+from pydantic import BaseModel, Field
 
 from pipecat.frames.frames import (
     Frame,
@@ -74,20 +75,28 @@ def assistant(self) -> 'AnthropicAssistantContextAggregator':
 class AnthropicLLMService(LLMService):
     """This class implements inference with Anthropic's AI models
     """
+    class InputParams(BaseModel):
+        enable_prompt_caching_beta: Optional[bool] = False
+        max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
+        temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
+        top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
+        top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
 
     def __init__(
             self,
             *,
             api_key: str,
             model: str = "claude-3-5-sonnet-20240620",
-            max_tokens: int = 4096,
-            enable_prompt_caching_beta: bool = False,
+            params: InputParams = InputParams(),
            **kwargs):
         super().__init__(**kwargs)
         self._client = AsyncAnthropic(api_key=api_key)
         self.set_model_name(model)
-        self._max_tokens = max_tokens
-        self._enable_prompt_caching_beta = enable_prompt_caching_beta
+        self._max_tokens = params.max_tokens
+        self._enable_prompt_caching_beta: bool = params.enable_prompt_caching_beta or False
+        self._temperature = params.temperature
+        self._top_k = params.top_k
+        self._top_p = params.top_p
 
     def can_generate_metrics(self) -> bool:
         return True
@@ -105,6 +114,26 @@ def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggr
             _assistant=assistant
         )
 
+    async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
+        logger.debug(f"Switching LLM enable_prompt_caching_beta to: [{enable_prompt_caching_beta}]")
+        self._enable_prompt_caching_beta = enable_prompt_caching_beta
+
+    async def set_max_tokens(self, max_tokens: int):
+        logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
+        self._max_tokens = max_tokens
+
+    async def set_temperature(self, temperature: float):
+        logger.debug(f"Switching LLM temperature to: [{temperature}]")
+        self._temperature = temperature
+
+    async def set_top_k(self, top_k: float):
+        logger.debug(f"Switching LLM top_k to: [{top_k}]")
+        self._top_k = top_k
+
+    async def set_top_p(self, top_p: float):
+        logger.debug(f"Switching LLM top_p to: [{top_p}]")
+        self._top_p = top_p
+
     async def _process_context(self, context: OpenAILLMContext):
         # Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
         # completion_tokens. We also estimate the completion tokens from output text
@@ -140,7 +169,10 @@ async def _process_context(self, context: OpenAILLMContext):
                 messages=messages,
                 model=self.model_name,
                 max_tokens=self._max_tokens,
-                stream=True)
+                stream=True,
+                temperature=self._temperature,
+                top_k=self._top_k,
+                top_p=self._top_p)
 
             await self.stop_ttfb_metrics()
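
For reviewers, a minimal usage sketch of the new `InputParams` API. Only the constructor signature, the `InputParams` field names, and the async setters come from this diff; the parameter values and the surrounding setup are illustrative assumptions, not part of the change:

```python
import os

from pipecat.services.anthropic import AnthropicLLMService

# New style: tuning options are grouped into a pydantic InputParams model
# instead of the old max_tokens / enable_prompt_caching_beta kwargs.
llm = AnthropicLLMService(
    api_key=os.getenv("ANTHROPIC_API_KEY"),
    model="claude-3-5-sonnet-20240620",
    params=AnthropicLLMService.InputParams(
        enable_prompt_caching_beta=True,
        max_tokens=1024,   # example value; defaults to 4096
        temperature=0.7,   # example value; defaults to NOT_GIVEN
        top_p=0.9,         # example value; defaults to NOT_GIVEN
    ),
)

# Settings can also be adjusted at runtime through the new async setters,
# e.g. from inside a running pipeline:
#   await llm.set_temperature(0.2)
#   await llm.set_max_tokens(2048)
```

Fields omitted from `InputParams` keep their defaults, so existing call sites only need to wrap the arguments they already pass.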