Input params for Anthropic LLM
markbackman committed Sep 20, 2024
1 parent 7127be6 commit 40c7dcb
Showing 1 changed file with 25 additions and 1 deletion.
src/pipecat/services/anthropic.py
@@ -13,6 +13,7 @@
 from PIL import Image
 from asyncio import CancelledError
 import re
+from pydantic import BaseModel, Field
 
 from pipecat.frames.frames import (
     Frame,
@@ -74,6 +75,10 @@ def assistant(self) -> 'AnthropicAssistantContextAggregator':
 class AnthropicLLMService(LLMService):
     """This class implements inference with Anthropic's AI models
     """
+    class InputParams(BaseModel):
+        temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
+        top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
+        top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
 
     def __init__(
             self,
@@ -82,12 +87,16 @@ def __init__(
             model: str = "claude-3-5-sonnet-20240620",
             max_tokens: int = 4096,
             enable_prompt_caching_beta: bool = False,
+            params: InputParams = InputParams(),
             **kwargs):
         super().__init__(**kwargs)
         self._client = AsyncAnthropic(api_key=api_key)
         self.set_model_name(model)
         self._max_tokens = max_tokens
         self._enable_prompt_caching_beta = enable_prompt_caching_beta
+        self._temperature = params.temperature
+        self._top_k = params.top_k
+        self._top_p = params.top_p
 
     def can_generate_metrics(self) -> bool:
         return True
@@ -105,6 +114,18 @@ def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggr
             _assistant=assistant
         )
 
+    async def set_temperature(self, temperature: float):
+        logger.debug(f"Switching LLM temperature to: [{temperature}]")
+        self._temperature = temperature
+
+    async def set_top_k(self, top_k: float):
+        logger.debug(f"Switching LLM top_k to: [{top_k}]")
+        self._top_k = top_k
+
+    async def set_top_p(self, top_p: float):
+        logger.debug(f"Switching LLM top_p to: [{top_p}]")
+        self._top_p = top_p
+
     async def _process_context(self, context: OpenAILLMContext):
         # Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
         # completion_tokens. We also estimate the completion tokens from output text
@@ -140,7 +161,10 @@ async def _process_context(self, context: OpenAILLMContext):
             messages=messages,
             model=self.model_name,
             max_tokens=self._max_tokens,
-            stream=True)
+            stream=True,
+            temperature=self._temperature,
+            top_k=self._top_k,
+            top_p=self._top_p)
 
         await self.stop_ttfb_metrics()
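For context, here is a minimal usage sketch (not part of the commit) showing how the new InputParams and runtime setters could be exercised from application code. The model name matches the default above; the environment variable, the pipeline wiring (elided here), and the specific sampling values are illustrative assumptions. Fields left unset keep Anthropic's NOT_GIVEN sentinel, so the messages.create() call falls back to the API's own defaults for them.

import asyncio
import os

from pipecat.services.anthropic import AnthropicLLMService


async def main():
    # Construct the service with explicit sampling parameters. Any InputParams
    # field left unset stays NOT_GIVEN, so Anthropic's server defaults apply.
    llm = AnthropicLLMService(
        api_key=os.environ["ANTHROPIC_API_KEY"],  # assumed env var
        model="claude-3-5-sonnet-20240620",
        params=AnthropicLLMService.InputParams(
            temperature=0.7,  # illustrative; validated to [0.0, 1.0]
            top_k=40,         # illustrative; validated to >= 0
        ),
    )

    # The setters added in this commit adjust parameters at runtime; the new
    # values take effect on the next _process_context() call.
    await llm.set_temperature(0.2)
    await llm.set_top_p(0.9)


asyncio.run(main())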
