Skip to content

Commit

Permalink
Input params to OpenAI LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
markbackman committed Sep 21, 2024
1 parent 58d9c84 commit 3b81cd4
Showing 1 changed file with 57 additions and 6 deletions.
63 changes: 57 additions & 6 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import httpx
from dataclasses import dataclass

from typing import AsyncGenerator, Dict, List, Literal
from typing import AsyncGenerator, Dict, List, Literal, Optional
from pydantic import BaseModel, Field

from loguru import logger
from PIL import Image
Expand Down Expand Up @@ -48,7 +49,7 @@
)

try:
from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError
from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError, NOT_GIVEN
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
Expand Down Expand Up @@ -81,11 +82,31 @@ class BaseOpenAILLMService(LLMService):
as well as tool choices and the tool, which is used if requesting function
calls from the LLM.
"""
class InputParams(BaseModel):
frequency_penalty: Optional[float] = Field(
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0)
presence_penalty: Optional[float] = Field(
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0)
seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)

def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs):
def __init__(
self,
*,
model: str,
api_key=None,
base_url=None,
params: InputParams = InputParams(),
**kwargs):
super().__init__(**kwargs)
self.set_model_name(model)
self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
self._frequency_penalty = params.frequency_penalty
self._presence_penalty = params.presence_penalty
self._seed = params.seed
self._temperature = params.temperature
self._top_p = params.top_p

def create_client(self, api_key=None, base_url=None, **kwargs):
return AsyncOpenAI(
Expand All @@ -100,6 +121,26 @@ def create_client(self, api_key=None, base_url=None, **kwargs):
def can_generate_metrics(self) -> bool:
return True

async def set_frequency_penalty(self, frequency_penalty: float):
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
self._frequency_penalty = frequency_penalty

async def set_presence_penalty(self, presence_penalty: float):
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
self._presence_penalty = presence_penalty

async def set_seed(self, seed: int):
logger.debug(f"Switching LLM seed to: [{seed}]")
self._seed = seed

async def set_temperature(self, temperature: float):
logger.debug(f"Switching LLM temperature to: [{temperature}]")
self._temperature = temperature

async def set_top_p(self, top_p: float):
logger.debug(f"Switching LLM top_p to: [{top_p}]")
self._top_p = top_p

async def get_chat_completions(
self,
context: OpenAILLMContext,
Expand All @@ -110,7 +151,12 @@ async def get_chat_completions(
messages=messages,
tools=context.tools,
tool_choice=context.tool_choice,
stream_options={"include_usage": True}
stream_options={"include_usage": True},
frequency_penalty=self._frequency_penalty,
presence_penalty=self._presence_penalty,
seed=self._seed,
temperature=self._temperature,
top_p=self._top_p
)
return chunks

Expand Down Expand Up @@ -248,8 +294,13 @@ def assistant(self) -> 'OpenAIAssistantContextAggregator':

class OpenAILLMService(BaseOpenAILLMService):

def __init__(self, *, model: str = "gpt-4o", **kwargs):
super().__init__(model=model, **kwargs)
def __init__(
self,
*,
model: str = "gpt-4o",
params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(),
**kwargs):
super().__init__(model=model, params=params, **kwargs)

@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
Expand Down

0 comments on commit 3b81cd4

Please sign in to comment.