Merge pull request #484 from pipecat-ai/mb/llm-input-params
Add input params for OpenAI, Anthropic, Together AI LLMs
markbackman authored Sep 21, 2024
2 parents 14acf05 + 219304c commit e8f8a49
Showing 5 changed files with 248 additions and 17 deletions.
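Taken together, the changes below let callers pass sampling parameters at construction time and adjust them later through async setters. A minimal sketch of the OpenAI case, using the `InputParams` fields and setter names from the diffs below; the key, model, and values are placeholders, not recommendations:

```python
import os

from pipecat.services.openai import OpenAILLMService

llm = OpenAILLMService(
    api_key=os.getenv("OPENAI_API_KEY"),
    model="gpt-4o",
    params=OpenAILLMService.InputParams(
        temperature=0.7,
        top_p=0.9,
        seed=42,  # fields left unset stay at NOT_GIVEN and are omitted from requests
    ),
)


async def make_more_deterministic():
    # The setters are async, so they are awaited from inside a running coroutine.
    await llm.set_temperature(0.0)
    await llm.set_seed(7)
```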
14 changes: 10 additions & 4 deletions CHANGELOG.md
@@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added configurable LLM parameters (e.g., temperature, top_p, max_tokens, seed)
for OpenAI, Anthropic, and Together AI services along with corresponding
setter functions.

- Added `sample_rate` as a constructor parameter for TTS services.

- Pipecat has a pipeline-based architecture. The pipeline consists of frame
processors linked to each other. The elements traveling across the pipeline
are called frames.
@@ -343,7 +349,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- It is now possible to specify a Silero VAD version when using `SileroVADAnalyzer`
or `SileroVAD`.

- Added `AsyncFrameProcessor` and `AsyncAIService`. Some services like
`DeepgramSTTService` need to process things asynchronously. For example, audio
is sent to Deepgram but transcriptions are not returned immediately. In these
cases we still require all frames (except system frames) to be pushed
@@ -360,7 +366,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `WhisperSTTService` model can now also be a string.

- Added missing * keyword separators in services.
- Added missing \* keyword separators in services.

### Fixed

@@ -437,7 +443,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new `TwilioFrameSerializer`. This is a new serializer that knows how to
serialize and deserialize audio frames from Twilio.

- Added Daily transport event: `on_dialout_answered`. See
https://reference-python.daily.co/api_reference.html#daily.EventHandler

- Added new `AzureSTTService`. This allows you to use Azure Speech-To-Text.
@@ -677,7 +683,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Daily transport support for dial-in use cases.

- Added Daily transport events: `on_dialout_connected`, `on_dialout_stopped`,
  `on_dialout_error` and `on_dialout_warning`. See
https://reference-python.daily.co/api_reference.html#daily.EventHandler

## [0.0.21] - 2024-05-22
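The `sample_rate` changelog entry above refers to a new TTS constructor parameter. A hedged sketch with Cartesia, the TTS service used in the example below; the voice id is the one from that example and 16000 Hz is just an illustrative value:

```python
import os

from pipecat.services.cartesia import CartesiaTTSService

tts = CartesiaTTSService(
    api_key=os.getenv("CARTESIA_API_KEY"),
    voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
    sample_rate=16000,  # new constructor parameter: output sample rate in Hz
)
```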
100 changes: 100 additions & 0 deletions examples/foundational/07l-interruptible-together.py
@@ -0,0 +1,100 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import aiohttp
import os
import sys

from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.together import TogetherLLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
from pipecat.vad.silero import SileroVADAnalyzer

from runner import configure

from loguru import logger

from dotenv import load_dotenv
load_dotenv(override=True)

logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)

transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer()
)
)

tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)

llm = TogetherLLMService(
api_key=os.getenv("TOGETHER_API_KEY"),
model=os.getenv("TOGETHER_MODEL"),
params=TogetherLLMService.InputParams(
temperature=1.0,
frequency_penalty=2.0,
presence_penalty=0.0,
top_p=0.9,
top_k=40
)
)

messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]

tma_in = LLMUserResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages)

pipeline = Pipeline([
transport.input(), # Transport user input
tma_in, # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
tma_out # Assistant spoken responses
])

task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))

@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
await task.queue_frames([LLMMessagesFrame(messages)])

runner = PipelineRunner()

await runner.run(task)


if __name__ == "__main__":
asyncio.run(main())
42 changes: 37 additions & 5 deletions src/pipecat/services/anthropic.py
@@ -13,6 +13,7 @@
from PIL import Image
from asyncio import CancelledError
import re
from pydantic import BaseModel, Field

from pipecat.frames.frames import (
Frame,
@@ -74,20 +75,28 @@ def assistant(self) -> 'AnthropicAssistantContextAggregator':
class AnthropicLLMService(LLMService):
"""This class implements inference with Anthropic's AI models
"""
class InputParams(BaseModel):
enable_prompt_caching_beta: Optional[bool] = False
max_tokens: Optional[int] = Field(default_factory=lambda: 4096, ge=1)
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)

def __init__(
self,
*,
api_key: str,
model: str = "claude-3-5-sonnet-20240620",
max_tokens: int = 4096,
enable_prompt_caching_beta: bool = False,
params: InputParams = InputParams(),
**kwargs):
super().__init__(**kwargs)
self._client = AsyncAnthropic(api_key=api_key)
self.set_model_name(model)
self._max_tokens = max_tokens
self._enable_prompt_caching_beta = enable_prompt_caching_beta
self._max_tokens = params.max_tokens
self._enable_prompt_caching_beta: bool = params.enable_prompt_caching_beta or False
self._temperature = params.temperature
self._top_k = params.top_k
self._top_p = params.top_p

def can_generate_metrics(self) -> bool:
return True
@@ -105,6 +114,26 @@ def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggr
_assistant=assistant
)

async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
logger.debug(f"Switching LLM enable_prompt_caching_beta to: [{enable_prompt_caching_beta}]")
self._enable_prompt_caching_beta = enable_prompt_caching_beta

async def set_max_tokens(self, max_tokens: int):
logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
self._max_tokens = max_tokens

async def set_temperature(self, temperature: float):
logger.debug(f"Switching LLM temperature to: [{temperature}]")
self._temperature = temperature

async def set_top_k(self, top_k: int):
logger.debug(f"Switching LLM top_k to: [{top_k}]")
self._top_k = top_k

async def set_top_p(self, top_p: float):
logger.debug(f"Switching LLM top_p to: [{top_p}]")
self._top_p = top_p

async def _process_context(self, context: OpenAILLMContext):
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
# completion_tokens. We also estimate the completion tokens from output text
@@ -140,7 +169,10 @@ async def _process_context(self, context: OpenAILLMContext):
messages=messages,
model=self.model_name,
max_tokens=self._max_tokens,
stream=True)
stream=True,
temperature=self._temperature,
top_k=self._top_k,
top_p=self._top_p)

await self.stop_ttfb_metrics()

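A corresponding sketch for the Anthropic service above, exercising the Anthropic-specific `InputParams` fields (`top_k`, `max_tokens`, `enable_prompt_caching_beta`); the values are illustrative only:

```python
import os

from pipecat.services.anthropic import AnthropicLLMService

llm = AnthropicLLMService(
    api_key=os.getenv("ANTHROPIC_API_KEY"),
    model="claude-3-5-sonnet-20240620",
    params=AnthropicLLMService.InputParams(
        enable_prompt_caching_beta=True,
        max_tokens=1024,
        temperature=0.7,
        top_k=40,
    ),
)


async def widen_sampling():
    # top_k and top_p can also be changed at runtime through the new async setters.
    await llm.set_top_k(80)
    await llm.set_top_p(0.95)
```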
63 changes: 57 additions & 6 deletions src/pipecat/services/openai.py
@@ -11,7 +11,8 @@
import httpx
from dataclasses import dataclass

from typing import AsyncGenerator, Dict, List, Literal
from typing import AsyncGenerator, Dict, List, Literal, Optional
from pydantic import BaseModel, Field

from loguru import logger
from PIL import Image
@@ -48,7 +49,7 @@
)

try:
from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError
from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError, NOT_GIVEN
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
@@ -81,11 +82,31 @@ class BaseOpenAILLMService(LLMService):
as well as tool choices and the tool, which is used if requesting function
calls from the LLM.
"""
class InputParams(BaseModel):
frequency_penalty: Optional[float] = Field(
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0)
presence_penalty: Optional[float] = Field(
default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0)
seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)

def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs):
def __init__(
self,
*,
model: str,
api_key=None,
base_url=None,
params: InputParams = InputParams(),
**kwargs):
super().__init__(**kwargs)
self.set_model_name(model)
self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
self._frequency_penalty = params.frequency_penalty
self._presence_penalty = params.presence_penalty
self._seed = params.seed
self._temperature = params.temperature
self._top_p = params.top_p

def create_client(self, api_key=None, base_url=None, **kwargs):
return AsyncOpenAI(
@@ -100,6 +121,26 @@ def create_client(self, api_key=None, base_url=None, **kwargs):
def can_generate_metrics(self) -> bool:
return True

async def set_frequency_penalty(self, frequency_penalty: float):
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
self._frequency_penalty = frequency_penalty

async def set_presence_penalty(self, presence_penalty: float):
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
self._presence_penalty = presence_penalty

async def set_seed(self, seed: int):
logger.debug(f"Switching LLM seed to: [{seed}]")
self._seed = seed

async def set_temperature(self, temperature: float):
logger.debug(f"Switching LLM temperature to: [{temperature}]")
self._temperature = temperature

async def set_top_p(self, top_p: float):
logger.debug(f"Switching LLM top_p to: [{top_p}]")
self._top_p = top_p

async def get_chat_completions(
self,
context: OpenAILLMContext,
@@ -110,7 +151,12 @@ async def get_chat_completions(
messages=messages,
tools=context.tools,
tool_choice=context.tool_choice,
stream_options={"include_usage": True}
stream_options={"include_usage": True},
frequency_penalty=self._frequency_penalty,
presence_penalty=self._presence_penalty,
seed=self._seed,
temperature=self._temperature,
top_p=self._top_p
)
return chunks

@@ -248,8 +294,13 @@ def assistant(self) -> 'OpenAIAssistantContextAggregator':

class OpenAILLMService(BaseOpenAILLMService):

def __init__(self, *, model: str = "gpt-4o", **kwargs):
super().__init__(model=model, **kwargs)
def __init__(
self,
*,
model: str = "gpt-4o",
params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(),
**kwargs):
super().__init__(model=model, params=params, **kwargs)

@staticmethod
def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
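One detail worth calling out in the OpenAI diff: the `InputParams` fields default to the SDK's `NOT_GIVEN` sentinel rather than `None`, so anything the caller leaves unset is simply omitted from the `chat.completions.create()` request and the API's own default applies. A small sketch of that behavior, assuming pydantic's standard behavior of not validating default values:

```python
from openai import NOT_GIVEN

from pipecat.services.openai import BaseOpenAILLMService

defaults = BaseOpenAILLMService.InputParams()
assert defaults.temperature is NOT_GIVEN  # omitted from the request entirely

tuned = BaseOpenAILLMService.InputParams(temperature=0.2, frequency_penalty=0.5)
assert tuned.temperature == 0.2
assert tuned.seed is NOT_GIVEN  # still omitted; the API's default behavior is used
```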
