Skip to content

Commit

Permalink
Merge pull request #230 from pipecat-ai/aleix/processor-names
Browse files Browse the repository at this point in the history
processor names
  • Loading branch information
aconchillo authored Jun 12, 2024
2 parents 5eb1b90 + 0225443 commit 8d92cba
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 52 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Allow specifying frame processors' name through a new `name` constructor
argument.

### Changed

- `daily_rest.DailyRoomProperties` now allows extra unknown parameters.
Expand Down
13 changes: 1 addition & 12 deletions src/pipecat/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

from itertools import chain

from typing import Callable, Coroutine, List

from pipecat.frames.frames import Frame, MetricsFrame, StartFrame
from pipecat.frames.frames import Frame
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor

Expand Down Expand Up @@ -81,9 +79,6 @@ async def cleanup(self):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, StartFrame) and self.metrics_enabled:
await self._send_initial_metrics()

if direction == FrameDirection.DOWNSTREAM:
await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)
elif direction == FrameDirection.UPSTREAM:
Expand All @@ -98,9 +93,3 @@ def _link_processors(self):
for curr in self._processors[1:]:
prev.link(curr)
prev = curr

async def _send_initial_metrics(self):
processors = self.processors_with_metrics()
ttfb = dict(zip([p.name for p in processors], [0] * len(processors)))
frame = MetricsFrame(ttfb=ttfb)
await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)
11 changes: 9 additions & 2 deletions src/pipecat/pipeline/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from pydantic import BaseModel

from pipecat.frames.frames import CancelFrame, EndFrame, ErrorFrame, Frame, StartFrame, StopTaskFrame
from pipecat.frames.frames import CancelFrame, EndFrame, ErrorFrame, Frame, MetricsFrame, StartFrame, StopTaskFrame
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.utils.utils import obj_count, obj_id

Expand Down Expand Up @@ -40,7 +41,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):

class PipelineTask:

def __init__(self, pipeline: FrameProcessor, params: PipelineParams = PipelineParams()):
def __init__(self, pipeline: BasePipeline, params: PipelineParams = PipelineParams()):
self.id: int = obj_id()
self.name: str = f"{self.__class__.__name__}#{obj_count(self)}"

Expand Down Expand Up @@ -89,12 +90,18 @@ async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
else:
raise Exception("Frames must be an iterable or async iterable")

def _initial_metrics_frame(self) -> MetricsFrame:
processors = self._pipeline.processors_with_metrics()
ttfb = dict(zip([p.name for p in processors], [0] * len(processors)))
return MetricsFrame(ttfb=ttfb)

async def _process_down_queue(self):
start_frame = StartFrame(
allow_interruptions=self._params.allow_interruptions,
enable_metrics=self._params.enable_metrics,
)
await self._source.process_frame(start_frame, FrameDirection.DOWNSTREAM)
await self._source.process_frame(self._initial_metrics_frame(), FrameDirection.DOWNSTREAM)

running = True
should_cleanup = True
Expand Down
4 changes: 2 additions & 2 deletions src/pipecat/processors/frame_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ class FrameDirection(Enum):

class FrameProcessor:

def __init__(self, loop: asyncio.AbstractEventLoop | None = None):
def __init__(self, name: str | None = None, loop: asyncio.AbstractEventLoop | None = None):
self.id: int = obj_id()
self.name = f"{self.__class__.__name__}#{obj_count(self)}"
self.name = name or f"{self.__class__.__name__}#{obj_count(self)}"
self._prev: "FrameProcessor" | None = None
self._next: "FrameProcessor" | None = None
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop()
Expand Down
25 changes: 13 additions & 12 deletions src/pipecat/services/ai_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@


class AIService(FrameProcessor):
def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)

async def start(self, frame: StartFrame):
pass
Expand Down Expand Up @@ -61,8 +61,8 @@ async def process_generator(self, generator: AsyncGenerator[Frame, None]):
class LLMService(AIService):
"""This class is a no-op but serves as a base class for LLM services."""

def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._callbacks = {}
self._start_callbacks = {}

Expand Down Expand Up @@ -91,8 +91,8 @@ async def call_start_function(self, function_name: str):


class TTSService(AIService):
def __init__(self, aggregate_sentences: bool = True):
super().__init__()
def __init__(self, aggregate_sentences: bool = True, **kwargs):
super().__init__(**kwargs)
self._aggregate_sentences: bool = aggregate_sentences
self._current_sentence: str = ""

Expand Down Expand Up @@ -146,8 +146,9 @@ def __init__(self,
max_silence_secs: float = 0.3,
max_buffer_secs: float = 1.5,
sample_rate: int = 16000,
num_channels: int = 1):
super().__init__()
num_channels: int = 1,
**kwargs):
super().__init__(**kwargs)
self._min_volume = min_volume
self._max_silence_secs = max_silence_secs
self._max_buffer_secs = max_buffer_secs
Expand Down Expand Up @@ -216,8 +217,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):

class ImageGenService(AIService):

def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)

# Renders the image. Returns an Image object.
@abstractmethod
Expand All @@ -237,8 +238,8 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
class VisionService(AIService):
"""VisionService is a base class for vision services."""

def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._describe_text = None

@abstractmethod
Expand Down
5 changes: 2 additions & 3 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
try:
from openai import AsyncOpenAI, AsyncStream, BadRequestError
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionFunctionMessageParam,
ChatCompletionMessageParam,
Expand All @@ -68,8 +67,8 @@ class BaseOpenAILLMService(LLMService):
calls from the LLM.
"""

def __init__(self, model: str, api_key=None, base_url=None):
super().__init__()
def __init__(self, model: str, api_key=None, base_url=None, **kwargs):
super().__init__(**kwargs)
self._model: str = model
self._client = self.create_client(api_key=api_key, base_url=base_url)

Expand Down
4 changes: 2 additions & 2 deletions src/pipecat/transports/base_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

class BaseInputTransport(FrameProcessor):

def __init__(self, params: TransportParams):
super().__init__()
def __init__(self, params: TransportParams, **kwargs):
super().__init__(**kwargs)

self._params = params

Expand Down
9 changes: 5 additions & 4 deletions src/pipecat/transports/base_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@

class BaseOutputTransport(FrameProcessor):

def __init__(self, params: TransportParams):
super().__init__()
def __init__(self, params: TransportParams, **kwargs):
super().__init__(**kwargs)

self._params = params

Expand Down Expand Up @@ -135,6 +135,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
await self._handle_interruptions(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, MetricsFrame):
self.send_metrics(frame)
await self.push_frame(frame, direction)
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
Expand Down Expand Up @@ -182,8 +185,6 @@ def _sink_thread_handler(self):
self._set_camera_images(frame.images)
elif isinstance(frame, TransportMessageFrame):
self.send_message(frame)
elif isinstance(frame, MetricsFrame):
self.send_metrics(frame)
else:
future = asyncio.run_coroutine_threadsafe(
self._internal_push_frame(frame), self.get_event_loop())
Expand Down
7 changes: 6 additions & 1 deletion src/pipecat/transports/base_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ class TransportParams(BaseModel):

class BaseTransport(ABC):

def __init__(self, loop: asyncio.AbstractEventLoop | None):
def __init__(self,
input_name: str | None = None,
output_name: str | None = None,
loop: asyncio.AbstractEventLoop | None = None):
self._input_name = input_name
self._output_name = output_name
self._loop = loop or asyncio.get_running_loop()
self._event_handlers: dict = {}

Expand Down
17 changes: 10 additions & 7 deletions src/pipecat/transports/network/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def __init__(
host: str,
port: int,
params: WebsocketServerParams,
callbacks: WebsocketServerCallbacks):
super().__init__(params)
callbacks: WebsocketServerCallbacks,
**kwargs):
super().__init__(params, **kwargs)

self._host = host
self._port = port
Expand Down Expand Up @@ -98,8 +99,8 @@ async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, p

class WebsocketServerOutputTransport(BaseOutputTransport):

def __init__(self, params: WebsocketServerParams):
super().__init__(params)
def __init__(self, params: WebsocketServerParams, **kwargs):
super().__init__(params, **kwargs)

self._params = params

Expand Down Expand Up @@ -153,8 +154,10 @@ def __init__(
host: str = "localhost",
port: int = 8765,
params: WebsocketServerParams = WebsocketServerParams(),
input_name: str | None = None,
output_name: str | None = None,
loop: asyncio.AbstractEventLoop | None = None):
super().__init__(loop)
super().__init__(input_name=input_name, output_name=output_name, loop=loop)
self._host = host
self._port = port
self._params = params
Expand All @@ -175,12 +178,12 @@ def __init__(
def input(self) -> FrameProcessor:
if not self._input:
self._input = WebsocketServerInputTransport(
self._host, self._port, self._params, self._callbacks)
self._host, self._port, self._params, self._callbacks, name=self._input_name)
return self._input

def output(self) -> FrameProcessor:
if not self._output:
self._output = WebsocketServerOutputTransport(self._params)
self._output = WebsocketServerOutputTransport(self._params, name=self._output_name)
return self._output

async def _on_client_connected(self, websocket):
Expand Down
16 changes: 9 additions & 7 deletions src/pipecat/transports/services/daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,8 @@ def _video_frame_received(self, participant_id, video_frame):

class DailyInputTransport(BaseInputTransport):

def __init__(self, client: DailyTransportClient, params: DailyParams):
super().__init__(params)
def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs):
super().__init__(params, **kwargs)

self._client = client

Expand Down Expand Up @@ -609,8 +609,8 @@ def _on_participant_video_frame(self, participant_id: str, buffer, size, format)

class DailyOutputTransport(BaseOutputTransport):

def __init__(self, client: DailyTransportClient, params: DailyParams):
super().__init__(params)
def __init__(self, client: DailyTransportClient, params: DailyParams, **kwargs):
super().__init__(params, **kwargs)

self._client = client

Expand Down Expand Up @@ -662,8 +662,10 @@ def __init__(
token: str | None,
bot_name: str,
params: DailyParams,
input_name: str | None = None,
output_name: str | None = None,
loop: asyncio.AbstractEventLoop | None = None):
super().__init__(loop)
super().__init__(input_name=input_name, output_name=output_name, loop=loop)

callbacks = DailyCallbacks(
on_joined=self._on_joined,
Expand Down Expand Up @@ -708,12 +710,12 @@ def __init__(

def input(self) -> FrameProcessor:
if not self._input:
self._input = DailyInputTransport(self._client, self._params)
self._input = DailyInputTransport(self._client, self._params, name=self._input_name)
return self._input

def output(self) -> FrameProcessor:
if not self._output:
self._output = DailyOutputTransport(self._client, self._params)
self._output = DailyOutputTransport(self._client, self._params, name=self._output_name)
return self._output

#
Expand Down

0 comments on commit 8d92cba

Please sign in to comment.