From 3b81cd462d36c83b5d741aed253b56932a9ae06c Mon Sep 17 00:00:00 2001
From: Mark Backman
Date: Fri, 20 Sep 2024 13:41:04 -0400
Subject: [PATCH] Input params to OpenAI LLM

---
 src/pipecat/services/openai.py | 63 ++++++++++++++++++++++++++++++----
 1 file changed, 57 insertions(+), 6 deletions(-)

diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py
index 6fe710b5e..440281c90 100644
--- a/src/pipecat/services/openai.py
+++ b/src/pipecat/services/openai.py
@@ -11,7 +11,8 @@
 import httpx
 from dataclasses import dataclass
-from typing import AsyncGenerator, Dict, List, Literal
+from typing import AsyncGenerator, Dict, List, Literal, Optional
+from pydantic import BaseModel, Field
 
 from loguru import logger
 from PIL import Image
@@ -48,7 +49,7 @@
 )
 
 try:
-    from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError
+    from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError, NOT_GIVEN
     from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
 except ModuleNotFoundError as e:
     logger.error(f"Exception: {e}")
@@ -81,11 +82,31 @@ class BaseOpenAILLMService(LLMService):
     as well as tool choices and the tool, which is used if requesting function
     calls from the LLM.
     """
+    class InputParams(BaseModel):
+        frequency_penalty: Optional[float] = Field(
+            default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0)
+        presence_penalty: Optional[float] = Field(
+            default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0)
+        seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
+        temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
+        top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
 
-    def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs):
+    def __init__(
+            self,
+            *,
+            model: str,
+            api_key=None,
+            base_url=None,
+            params: InputParams = InputParams(),
+            **kwargs):
         super().__init__(**kwargs)
         self.set_model_name(model)
         self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
+        self._frequency_penalty = params.frequency_penalty
+        self._presence_penalty = params.presence_penalty
+        self._seed = params.seed
+        self._temperature = params.temperature
+        self._top_p = params.top_p
 
     def create_client(self, api_key=None, base_url=None, **kwargs):
         return AsyncOpenAI(
@@ -100,6 +121,26 @@
     def can_generate_metrics(self) -> bool:
         return True
 
+    async def set_frequency_penalty(self, frequency_penalty: float):
+        logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
+        self._frequency_penalty = frequency_penalty
+
+    async def set_presence_penalty(self, presence_penalty: float):
+        logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
+        self._presence_penalty = presence_penalty
+
+    async def set_seed(self, seed: int):
+        logger.debug(f"Switching LLM seed to: [{seed}]")
+        self._seed = seed
+
+    async def set_temperature(self, temperature: float):
+        logger.debug(f"Switching LLM temperature to: [{temperature}]")
+        self._temperature = temperature
+
+    async def set_top_p(self, top_p: float):
+        logger.debug(f"Switching LLM top_p to: [{top_p}]")
+        self._top_p = top_p
+
     async def get_chat_completions(
             self,
             context: OpenAILLMContext,
@@ -110,7 +151,12 @@ async def get_chat_completions(
             messages=messages,
             tools=context.tools,
             tool_choice=context.tool_choice,
-            stream_options={"include_usage": True}
+            stream_options={"include_usage": True},
+            frequency_penalty=self._frequency_penalty,
+            presence_penalty=self._presence_penalty,
+            seed=self._seed,
+            temperature=self._temperature,
+            top_p=self._top_p
         )
         return chunks
 
@@ -248,8 +294,13 @@ def assistant(self) -> 'OpenAIAssistantContextAggregator':
 
 
 class OpenAILLMService(BaseOpenAILLMService):
-    def __init__(self, *, model: str = "gpt-4o", **kwargs):
-        super().__init__(model=model, **kwargs)
+    def __init__(
+            self,
+            *,
+            model: str = "gpt-4o",
+            params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(),
+            **kwargs):
+        super().__init__(model=model, params=params, **kwargs)
 
     @staticmethod
     def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
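
Usage sketch (illustrative, not part of the patch): one way a pipecat app might pass the new InputParams to the service added above. The parameter values are arbitrary examples, and reading the key from the environment is an assumption of this sketch.

    import os

    from pipecat.services.openai import OpenAILLMService

    # InputParams is inherited from BaseOpenAILLMService; any field left unset
    # keeps its NOT_GIVEN default and is omitted from chat.completions.create().
    llm = OpenAILLMService(
        model="gpt-4o",
        api_key=os.environ["OPENAI_API_KEY"],
        params=OpenAILLMService.InputParams(
            temperature=0.7,
            top_p=0.9,
            frequency_penalty=0.5,
            seed=42,
        ),
    )

    # The async setters added in this patch allow runtime adjustment, e.g.:
    #   await llm.set_temperature(0.2)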