diff --git a/examples/foundational/07l-interruptible-together.py b/examples/foundational/07l-interruptible-together.py
index 41befb67f..d5afa6d0d 100644
--- a/examples/foundational/07l-interruptible-together.py
+++ b/examples/foundational/07l-interruptible-together.py
@@ -57,10 +57,12 @@ async def main():
             model=os.getenv("TOGETHER_MODEL"),
             params=TogetherLLMService.InputParams(
                 temperature=1.0,
-                frequency_penalty=2.0,
-                presence_penalty=0.0,
                 top_p=0.9,
-                top_k=40
+                top_k=40,
+                extra={
+                    "frequency_penalty": 2.0,
+                    "presence_penalty": 0.0,
+                }
             )
         )
diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic.py
index ea1756f8c..421196e2c 100644
--- a/src/pipecat/services/anthropic.py
+++ b/src/pipecat/services/anthropic.py
@@ -8,7 +8,7 @@
 import json
 import io
 import copy
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 from dataclasses import dataclass
 from PIL import Image
 from asyncio import CancelledError
@@ -81,6 +81,7 @@ class InputParams(BaseModel):
         temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
         top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
         top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
+        extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
 
     def __init__(
             self,
@@ -97,6 +98,7 @@ def __init__(
         self._temperature = params.temperature
         self._top_k = params.top_k
         self._top_p = params.top_p
+        self._extra = params.extra if isinstance(params.extra, dict) else {}
 
     def can_generate_metrics(self) -> bool:
         return True
@@ -134,6 +136,10 @@ async def set_top_p(self, top_p: float):
         logger.debug(f"Switching LLM top_p to: [{top_p}]")
         self._top_p = top_p
 
+    async def set_extra(self, extra: Dict[str, Any]):
+        logger.debug(f"Switching LLM extra to: [{extra}]")
+        self._extra = extra
+
     async def _process_context(self, context: OpenAILLMContext):
         # Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
         # completion_tokens. We also estimate the completion tokens from output text
@@ -163,16 +169,21 @@ async def _process_context(self, context: OpenAILLMContext):
             await self.start_ttfb_metrics()
 
-            response = await api_call(
-                tools=context.tools or [],
-                system=context.system,
-                messages=messages,
-                model=self.model_name,
-                max_tokens=self._max_tokens,
-                stream=True,
-                temperature=self._temperature,
-                top_k=self._top_k,
-                top_p=self._top_p)
+            params = {
+                "tools": context.tools or [],
+                "system": context.system,
+                "messages": messages,
+                "model": self.model_name,
+                "max_tokens": self._max_tokens,
+                "stream": True,
+                "temperature": self._temperature,
+                "top_k": self._top_k,
+                "top_p": self._top_p
+            }
+
+            params.update(self._extra)
+
+            response = await api_call(**params)
 
             await self.stop_ttfb_metrics()
diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py
index 274a14820..4203f8194 100644
--- a/src/pipecat/services/openai.py
+++ b/src/pipecat/services/openai.py
@@ -11,7 +11,7 @@
 import httpx
 from dataclasses import dataclass
-from typing import AsyncGenerator, Dict, List, Literal, Optional
+from typing import Any, AsyncGenerator, Dict, List, Literal, Optional
 from pydantic import BaseModel, Field
 from loguru import logger
@@ -90,6 +90,7 @@ class InputParams(BaseModel):
         seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
         temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
         top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
+        extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
 
     def __init__(
             self,
@@ -107,6 +108,7 @@ def __init__(
         self._seed = params.seed
         self._temperature = params.temperature
         self._top_p = params.top_p
+        self._extra = params.extra if isinstance(params.extra, dict) else {}
 
     def create_client(self, api_key=None, base_url=None, **kwargs):
         return AsyncOpenAI(
@@ -141,23 +143,32 @@ async def set_top_p(self, top_p: float):
         logger.debug(f"Switching LLM top_p to: [{top_p}]")
         self._top_p = top_p
 
+    async def set_extra(self, extra: Dict[str, Any]):
+        logger.debug(f"Switching LLM extra to: [{extra}]")
+        self._extra = extra
+
     async def get_chat_completions(
             self,
            context: OpenAILLMContext,
            messages: List[ChatCompletionMessageParam]) -> AsyncStream[ChatCompletionChunk]:
-        chunks = await self._client.chat.completions.create(
-            model=self.model_name,
-            stream=True,
-            messages=messages,
-            tools=context.tools,
-            tool_choice=context.tool_choice,
-            stream_options={"include_usage": True},
-            frequency_penalty=self._frequency_penalty,
-            presence_penalty=self._presence_penalty,
-            seed=self._seed,
-            temperature=self._temperature,
-            top_p=self._top_p
-        )
+
+        params = {
+            "model": self.model_name,
+            "stream": True,
+            "messages": messages,
+            "tools": context.tools,
+            "tool_choice": context.tool_choice,
+            "stream_options": {"include_usage": True},
+            "frequency_penalty": self._frequency_penalty,
+            "presence_penalty": self._presence_penalty,
+            "seed": self._seed,
+            "temperature": self._temperature,
+            "top_p": self._top_p,
+        }
+
+        params.update(self._extra)
+
+        chunks = await self._client.chat.completions.create(**params)
         return chunks
 
     async def _stream_chat_completions(
diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together.py
index 4c8a5527d..ce8c62730 100644
--- a/src/pipecat/services/together.py
+++ b/src/pipecat/services/together.py
@@ -9,7 +9,7 @@
 import uuid
 from pydantic import BaseModel, Field
-from typing import List
+from typing import Any, Dict, List, Optional
 from dataclasses import dataclass
 from asyncio import CancelledError
@@ -64,6 +64,7 @@ class InputParams(BaseModel):
         temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
         top_k: Optional[int] = Field(default=None, ge=0)
         top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
+        extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
 
     def __init__(
             self,
@@ -81,6 +82,7 @@ def __init__(
         self._temperature = params.temperature
         self._top_k = params.top_k
         self._top_p = params.top_p
+        self._extra = params.extra if isinstance(params.extra, dict) else {}
 
     def can_generate_metrics(self) -> bool:
         return True
@@ -118,6 +120,10 @@ async def set_top_p(self, top_p: float):
         logger.debug(f"Switching LLM top_p to: [{top_p}]")
         self._top_p = top_p
 
+    async def set_extra(self, extra: Dict[str, Any]):
+        logger.debug(f"Switching LLM extra to: [{extra}]")
+        self._extra = extra
+
     async def _process_context(self, context: OpenAILLMContext):
         try:
             await self.push_frame(LLMFullResponseStartFrame())
@@ -127,17 +133,21 @@ async def _process_context(self, context: OpenAILLMContext):
             await self.start_ttfb_metrics()
 
-            stream = await self._client.chat.completions.create(
-                messages=context.messages,
-                model=self.model_name,
-                max_tokens=self._max_tokens,
-                stream=True,
-                frequency_penalty=self._frequency_penalty,
-                presence_penalty=self._presence_penalty,
-                temperature=self._temperature,
-                top_k=self._top_k,
-                top_p=self._top_p
-            )
+            params = {
+                "messages": context.messages,
+                "model": self.model_name,
+                "max_tokens": self._max_tokens,
+                "stream": True,
+                "frequency_penalty": self._frequency_penalty,
+                "presence_penalty": self._presence_penalty,
+                "temperature": self._temperature,
+                "top_k": self._top_k,
+                "top_p": self._top_p
+            }
+
+            params.update(self._extra)
+
+            stream = await self._client.chat.completions.create(**params)
 
             # Function calling
             got_first_chunk = False
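For reference, a minimal usage sketch (not part of the diff) of the new `extra` field and `set_extra()` method, following the pattern in the updated Together example above. The `TOGETHER_API_KEY` env var and the `repetition_penalty` / `min_p` keys are illustrative assumptions, not taken from the diff:

```python
import os

from pipecat.services.together import TogetherLLMService

# Keys placed in `extra` are merged into the request kwargs via
# params.update(self._extra), so provider-specific options without
# first-class InputParams fields can still reach the provider.
llm = TogetherLLMService(
    api_key=os.getenv("TOGETHER_API_KEY"),  # assumed env var name
    model=os.getenv("TOGETHER_MODEL"),
    params=TogetherLLMService.InputParams(
        temperature=1.0,
        extra={"repetition_penalty": 1.1},  # illustrative provider-specific option
    ),
)

# set_extra() replaces the whole dict; the new values apply from the next request on.
# await llm.set_extra({"repetition_penalty": 1.0, "min_p": 0.05})
```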