Add extra input param to LLMs
markbackman committed Sep 21, 2024
1 parent 26a64af commit c73111a
Showing 4 changed files with 74 additions and 40 deletions.
examples/foundational/07l-interruptible-together.py (8 changes: 5 additions & 3 deletions)

@@ -57,10 +57,12 @@ async def main():
             model=os.getenv("TOGETHER_MODEL"),
             params=TogetherLLMService.InputParams(
                 temperature=1.0,
-                frequency_penalty=2.0,
-                presence_penalty=0.0,
                 top_p=0.9,
-                top_k=40
+                top_k=40,
+                extra={
+                    "frequency_penalty": 2.0,
+                    "presence_penalty": 0.0,
+                }
             )
         )
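For comparison, here is a minimal sketch (not part of this commit) of the same pattern with the OpenAI service. It assumes OpenAILLMService is constructed the same way as TogetherLLMService above and exposes the InputParams model changed in src/pipecat/services/openai.py below; the model name and the logit_bias value are purely illustrative.

import os

from pipecat.services.openai import OpenAILLMService

llm = OpenAILLMService(
    api_key=os.getenv("OPENAI_API_KEY"),
    model="gpt-4o",
    params=OpenAILLMService.InputParams(
        temperature=0.7,
        # Keys in `extra` are merged into the chat.completions.create() call via
        # params.update(self._extra), so provider options without a first-class
        # InputParams field can still be passed through.
        extra={"logit_bias": {1734: -100}},  # illustrative token bias
    ),
)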
src/pipecat/services/anthropic.py (33 changes: 22 additions & 11 deletions)

@@ -8,7 +8,7 @@
 import json
 import io
 import copy
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 from dataclasses import dataclass
 from PIL import Image
 from asyncio import CancelledError

@@ -81,6 +81,7 @@ class InputParams(BaseModel):
         temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
         top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
         top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
+        extra: Optional[Dict[str, Any]] = Field(default_factory=dict)

     def __init__(
         self,

@@ -97,6 +98,7 @@ def __init__(
         self._temperature = params.temperature
         self._top_k = params.top_k
         self._top_p = params.top_p
+        self._extra = params.extra if isinstance(params.extra, dict) else {}

     def can_generate_metrics(self) -> bool:
         return True

@@ -134,6 +136,10 @@ async def set_top_p(self, top_p: float):
         logger.debug(f"Switching LLM top_p to: [{top_p}]")
         self._top_p = top_p

+    async def set_extra(self, extra: Dict[str, Any]):
+        logger.debug(f"Switching LLM extra to: [{extra}]")
+        self._extra = extra
+
     async def _process_context(self, context: OpenAILLMContext):
         # Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
         # completion_tokens. We also estimate the completion tokens from output text

@@ -163,16 +169,21 @@ async def _process_context(self, context: OpenAILLMContext):

             await self.start_ttfb_metrics()

-            response = await api_call(
-                tools=context.tools or [],
-                system=context.system,
-                messages=messages,
-                model=self.model_name,
-                max_tokens=self._max_tokens,
-                stream=True,
-                temperature=self._temperature,
-                top_k=self._top_k,
-                top_p=self._top_p)
+            params = {
+                "tools": context.tools or [],
+                "system": context.system,
+                "messages": messages,
+                "model": self.model_name,
+                "max_tokens": self._max_tokens,
+                "stream": True,
+                "temperature": self._temperature,
+                "top_k": self._top_k,
+                "top_p": self._top_p
+            }
+
+            params.update(self._extra)
+
+            response = await api_call(**params)

             await self.stop_ttfb_metrics()
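One consequence of the merge order above (and in the OpenAI and Together services below): params.update(self._extra) runs after the first-class fields are assembled, so a key in extra that collides with one of them silently overrides it. A standalone sketch of that behavior; the model name and the metadata key are illustrative values, not taken from this commit.

# Plain-Python illustration of the merge order used in _process_context().
params = {
    "model": "claude-3-5-sonnet-20240620",
    "temperature": 0.5,
    "top_k": 40,
}
extra = {"temperature": 0.9, "metadata": {"user_id": "abc"}}

params.update(extra)

print(params["temperature"])  # 0.9 -- the colliding key from `extra` wins
print(params["metadata"])     # {'user_id': 'abc'} -- unknown keys pass straight through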
src/pipecat/services/openai.py (39 changes: 25 additions & 14 deletions)

@@ -11,7 +11,7 @@
 import httpx
 from dataclasses import dataclass

-from typing import AsyncGenerator, Dict, List, Literal, Optional
+from typing import Any, AsyncGenerator, Dict, List, Literal, Optional
 from pydantic import BaseModel, Field

 from loguru import logger

@@ -90,6 +90,7 @@ class InputParams(BaseModel):
         seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
         temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
         top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
+        extra: Optional[Dict[str, Any]] = Field(default_factory=dict)

     def __init__(
         self,

@@ -107,6 +108,7 @@ def __init__(
         self._seed = params.seed
         self._temperature = params.temperature
         self._top_p = params.top_p
+        self._extra = params.extra if isinstance(params.extra, dict) else {}

     def create_client(self, api_key=None, base_url=None, **kwargs):
         return AsyncOpenAI(

@@ -141,23 +143,32 @@ async def set_top_p(self, top_p: float):
         logger.debug(f"Switching LLM top_p to: [{top_p}]")
         self._top_p = top_p

+    async def set_extra(self, extra: Dict[str, Any]):
+        logger.debug(f"Switching LLM extra to: [{extra}]")
+        self._extra = extra
+
     async def get_chat_completions(
             self,
             context: OpenAILLMContext,
             messages: List[ChatCompletionMessageParam]) -> AsyncStream[ChatCompletionChunk]:
-        chunks = await self._client.chat.completions.create(
-            model=self.model_name,
-            stream=True,
-            messages=messages,
-            tools=context.tools,
-            tool_choice=context.tool_choice,
-            stream_options={"include_usage": True},
-            frequency_penalty=self._frequency_penalty,
-            presence_penalty=self._presence_penalty,
-            seed=self._seed,
-            temperature=self._temperature,
-            top_p=self._top_p
-        )
+
+        params = {
+            "model": self.model_name,
+            "stream": True,
+            "messages": messages,
+            "tools": context.tools,
+            "tool_choice": context.tool_choice,
+            "stream_options": {"include_usage": True},
+            "frequency_penalty": self._frequency_penalty,
+            "presence_penalty": self._presence_penalty,
+            "seed": self._seed,
+            "temperature": self._temperature,
+            "top_p": self._top_p,
+        }
+
+        params.update(self._extra)
+
+        chunks = await self._client.chat.completions.create(**params)
         return chunks

     async def _stream_chat_completions(
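The new set_extra() method also makes the pass-through options adjustable at runtime. A small sketch under assumptions: llm is a BaseOpenAILLMService (or subclass) instance created as in the earlier sketch, and this coroutine runs inside the pipeline's event loop. Note that set_extra() replaces the stored dict wholesale rather than merging with the previous value.

from typing import Any, Dict

async def tighten_generation(llm) -> None:
    # max_tokens has no first-class InputParams field in openai.py, so it is a
    # natural candidate for `extra`; the value 256 is illustrative.
    new_extra: Dict[str, Any] = {"max_tokens": 256}
    await llm.set_extra(new_extra)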
src/pipecat/services/together.py (34 changes: 22 additions & 12 deletions)

@@ -9,7 +9,7 @@
 import uuid
 from pydantic import BaseModel, Field

-from typing import List
+from typing import Any, Dict, List, Optional
 from dataclasses import dataclass
 from asyncio import CancelledError


@@ -64,6 +64,7 @@ class InputParams(BaseModel):
         temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
         top_k: Optional[int] = Field(default=None, ge=0)
         top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
+        extra: Optional[Dict[str, Any]] = Field(default_factory=dict)

     def __init__(
         self,

@@ -81,6 +82,7 @@ def __init__(
         self._temperature = params.temperature
         self._top_k = params.top_k
         self._top_p = params.top_p
+        self._extra = params.extra if isinstance(params.extra, dict) else {}

     def can_generate_metrics(self) -> bool:
         return True

@@ -118,6 +120,10 @@ async def set_top_p(self, top_p: float):
         logger.debug(f"Switching LLM top_p to: [{top_p}]")
         self._top_p = top_p

+    async def set_extra(self, extra: Dict[str, Any]):
+        logger.debug(f"Switching LLM extra to: [{extra}]")
+        self._extra = extra
+
     async def _process_context(self, context: OpenAILLMContext):
         try:
             await self.push_frame(LLMFullResponseStartFrame())

@@ -127,17 +133,21 @@ async def _process_context(self, context: OpenAILLMContext):

             await self.start_ttfb_metrics()

-            stream = await self._client.chat.completions.create(
-                messages=context.messages,
-                model=self.model_name,
-                max_tokens=self._max_tokens,
-                stream=True,
-                frequency_penalty=self._frequency_penalty,
-                presence_penalty=self._presence_penalty,
-                temperature=self._temperature,
-                top_k=self._top_k,
-                top_p=self._top_p
-            )
+            params = {
+                "messages": context.messages,
+                "model": self.model_name,
+                "max_tokens": self._max_tokens,
+                "stream": True,
+                "frequency_penalty": self._frequency_penalty,
+                "presence_penalty": self._presence_penalty,
+                "temperature": self._temperature,
+                "top_k": self._top_k,
+                "top_p": self._top_p
+            }
+
+            params.update(self._extra)
+
+            stream = await self._client.chat.completions.create(**params)

             # Function calling
             got_first_chunk = False
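Note the guard in __init__ (params.extra if isinstance(params.extra, dict) else {}): when extra is omitted or None, the service stores an empty dict, and params.update({}) leaves the request untouched. A quick standalone check of that fallback; the model name and the repetition_penalty key are illustrative values, not taken from this commit.

from typing import Any, Dict, Optional

def resolve_extra(extra: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # Mirrors the guard used in the services' __init__ methods.
    return extra if isinstance(extra, dict) else {}

request = {"model": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "stream": True}
request.update(resolve_extra(None))                         # no-op: request unchanged
request.update(resolve_extra({"repetition_penalty": 1.2}))  # merged into the call kwargs

print(request)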
