From 357e66d64da74a0fc66ce45f2f9fcc0f6fb3fc0d Mon Sep 17 00:00:00 2001
From: Mark Backman
Date: Fri, 20 Sep 2024 16:18:25 -0400
Subject: [PATCH] Input params for Together AI LLM

---
 src/pipecat/services/together.py | 46 ++++++++++++++++++++++++++++++--
 1 file changed, 44 insertions(+), 2 deletions(-)

diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py
index 004236ac8..6bc3980ce 100644
--- a/src/pipecat/services/together.py
+++ b/src/pipecat/services/together.py
@@ -13,6 +13,7 @@ from asyncio import CancelledError
 import re
 import uuid
 
+from pydantic import BaseModel, Field
 
 from pipecat.frames.frames import (
     Frame,
@@ -58,18 +59,30 @@ def assistant(self) -> 'TogetherAssistantContextAggregator':
 class TogetherLLMService(LLMService):
     """This class implements inference with Together's Llama 3.1 models
     """
+    class InputParams(BaseModel):
+        frequency_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
+        max_tokens: Optional[int] = Field(default=4096, ge=1)
+        presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
+        temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
+        top_k: Optional[int] = Field(default=None, ge=0)
+        top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
 
     def __init__(
             self,
             *,
             api_key: str,
             model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-            max_tokens: int = 4096,
+            params: InputParams = InputParams(),
             **kwargs):
         super().__init__(**kwargs)
         self._client = AsyncTogether(api_key=api_key)
         self.set_model_name(model)
-        self._max_tokens = max_tokens
+        self._max_tokens = params.max_tokens
+        self._frequency_penalty = params.frequency_penalty
+        self._presence_penalty = params.presence_penalty
+        self._temperature = params.temperature
+        self._top_k = params.top_k
+        self._top_p = params.top_p
 
     def can_generate_metrics(self) -> bool:
         return True
@@ -83,6 +96,30 @@ def create_context_aggregator(context: OpenAILLMContext) -> TogetherContextAggre
             _assistant=assistant
         )
 
+    async def set_frequency_penalty(self, frequency_penalty: float):
+        logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
+        self._frequency_penalty = frequency_penalty
+
+    async def set_max_tokens(self, max_tokens: int):
+        logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
+        self._max_tokens = max_tokens
+
+    async def set_presence_penalty(self, presence_penalty: float):
+        logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
+        self._presence_penalty = presence_penalty
+
+    async def set_temperature(self, temperature: float):
+        logger.debug(f"Switching LLM temperature to: [{temperature}]")
+        self._temperature = temperature
+
+    async def set_top_k(self, top_k: float):
+        logger.debug(f"Switching LLM top_k to: [{top_k}]")
+        self._top_k = top_k
+
+    async def set_top_p(self, top_p: float):
+        logger.debug(f"Switching LLM top_p to: [{top_p}]")
+        self._top_p = top_p
+
     async def _process_context(self, context: OpenAILLMContext):
         try:
             await self.push_frame(LLMFullResponseStartFrame())
@@ -97,6 +134,11 @@ async def _process_context(self, context: OpenAILLMContext):
                 model=self.model_name,
                 max_tokens=self._max_tokens,
                 stream=True,
+                frequency_penalty=self._frequency_penalty,
+                presence_penalty=self._presence_penalty,
+                temperature=self._temperature,
+                top_k=self._top_k,
+                top_p=self._top_p
             )
 
             # Function calling
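
A minimal usage sketch of the new InputParams, not part of the patch itself: it assumes only the constructor signature and fields added above; the TOGETHER_API_KEY environment variable and the sample values are illustrative assumptions.

    import os

    from pipecat.services.together import TogetherLLMService

    # Construct the service with tuning parameters bundled in InputParams.
    # Unset fields keep the defaults declared in InputParams
    # (max_tokens=4096, all other fields None).
    llm = TogetherLLMService(
        api_key=os.getenv("TOGETHER_API_KEY"),
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        params=TogetherLLMService.InputParams(
            temperature=0.7,
            max_tokens=1024,
            top_p=0.9,
        ),
    )

    # The new async setters allow adjusting a value mid-session, e.g.:
    #     await llm.set_temperature(0.2)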