Skip to content

Commit

Permalink
pipeline(parallel): wait for slowest endframe
Browse files Browse the repository at this point in the history
When an EndFrame is sent through a ParallelPipeline that has multiple pipelines,
we want to wait until the slowest pipeline has finished before pushing the
EndFrame downstream. Otherwise, we could disconnect from the transport too early.
  • Loading branch information
aconchillo committed Dec 18, 2024
1 parent fb9f72d commit 2c46a07
Showing 1 changed file with 81 additions and 31 deletions.
112 changes: 81 additions & 31 deletions src/pipecat/pipeline/parallel_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@

import asyncio
from itertools import chain
from typing import Awaitable, Callable, List
from typing import Awaitable, Callable, Dict, List

from loguru import logger

from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame, SystemFrame
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
StartFrame,
StartInterruptionFrame,
SystemFrame,
)
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
Expand Down Expand Up @@ -72,11 +79,10 @@ def __init__(self, *args):
self._sources = []
self._sinks = []
self._seen_ids = set()
self._endframe_counter: Dict[int, int] = {}

self._up_queue = asyncio.Queue()
self._down_queue = asyncio.Queue()
self._up_task: asyncio.Task | None = None
self._down_task: asyncio.Task | None = None

self._pipelines = []

Expand Down Expand Up @@ -111,18 +117,19 @@ def processors_with_metrics(self) -> List[FrameProcessor]:
#

async def cleanup(self):
await asyncio.gather(*[s.cleanup() for s in self._sources])
await asyncio.gather(*[p.cleanup() for p in self._pipelines])

async def _start_tasks(self):
loop = self.get_event_loop()
self._up_task = loop.create_task(self._process_up_queue())
self._down_task = loop.create_task(self._process_down_queue())
await asyncio.gather(*[s.cleanup() for s in self._sinks])

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, StartFrame):
await self._start_tasks()
await self._start()
elif isinstance(frame, EndFrame):
self._endframe_counter[frame.id] = len(self._pipelines)
elif isinstance(frame, CancelFrame):
await self._cancel()

if direction == FrameDirection.UPSTREAM:
# If we get an upstream frame we process it in each sink.
Expand All @@ -131,36 +138,79 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# If we get a downstream frame we process it in each source.
await asyncio.gather(*[s.queue_frame(frame, direction) for s in self._sources])

# If we get an EndFrame we stop our queue processing tasks and wait on
# all the pipelines to finish.
if isinstance(frame, (CancelFrame, EndFrame)):
# Use None to indicate when queues should be done processing.
await self._up_queue.put(None)
await self._down_queue.put(None)
if self._up_task:
await self._up_task
if self._down_task:
await self._down_task
# Handle interruptions after everything has been cancelled.
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruption()
# Wait for tasks to finish.
elif isinstance(frame, EndFrame):
await self._stop()

# Start-up hook: process_frame() awaits this when it sees a StartFrame.
async def _start(self):
# Delegate to _create_tasks(), which spawns the up and down queue tasks.
await self._create_tasks()

async def _stop(self):
    """Shut down the queue-processing tasks once an EndFrame was received.

    The upstream task is cancelled outright; the downstream task is awaited
    so that it can finish on its own after it has seen the last EndFrame.
    Tasks that were never created (no start happened yet) are skipped
    instead of raising ``AttributeError``.
    """
    up_task = getattr(self, "_up_task", None)
    down_task = getattr(self, "_down_task", None)

    # The up task doesn't receive an EndFrame, so we just cancel it.
    if up_task is not None:
        up_task.cancel()
        try:
            await up_task
        except asyncio.CancelledError:
            # A task cancelled before its first step never reaches its own
            # CancelledError handler, so awaiting it re-raises here. The
            # cancellation is intentional: don't let it escape to the caller.
            pass

    # The down task waits for the last EndFrame sent by the internal
    # pipelines.
    if down_task is not None:
        await down_task

async def _cancel(self):
    """Cancel the up and down queue-processing tasks and wait for them.

    The up task is cancelled and awaited first, then the down task. Tasks
    that were never created are skipped, and the ``CancelledError`` produced
    by awaiting a task we just cancelled is suppressed so that it does not
    propagate into the caller.
    """
    for attr in ("_up_task", "_down_task"):
        task = getattr(self, attr, None)
        if task is None:
            # Nothing was started yet (or the attribute was never set).
            continue
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected: either the task had not run yet or it did not
            # swallow the cancellation itself.
            pass

async def _create_tasks(self):
    """Spawn the two queue-processing tasks on this processor's event loop.

    Stores the task running ``_process_up_queue()`` in ``self._up_task`` and
    the task running ``_process_down_queue()`` in ``self._down_task``.
    """
    spawn = self.get_event_loop().create_task
    self._up_task = spawn(self._process_up_queue())
    self._down_task = spawn(self._process_down_queue())

async def _drain_queues(self):
    """Discard every frame still pending in the up and down queues.

    ``asyncio.Queue.empty`` is a method; testing the bound method object
    itself is always truthy, which would make the drain loop a no-op. The
    queue is therefore polled with ``empty()`` and emptied with the
    non-blocking ``get_nowait()``.
    """
    for queue in (self._up_queue, self._down_queue):
        while not queue.empty():
            queue.get_nowait()
            # Mirror the consumer loops, which mark every fetched item done.
            queue.task_done()

# Reset queue processing for an interruption; process_frame() awaits this
# when it sees a StartInterruptionFrame. The three steps are order-dependent.
async def _handle_interruption(self):
# 1. Cancel the up/down queue tasks and wait for them.
await self._cancel()
# 2. Ask _drain_queues() to discard whatever is still queued.
await self._drain_queues()
# 3. Create fresh up/down queue tasks.
await self._create_tasks()

async def _parallel_push_frame(self, frame: Frame, direction: FrameDirection):
    """Push ``frame`` in ``direction`` unless its id was already pushed.

    Frame ids are remembered in ``self._seen_ids``; a frame whose id is
    already recorded there is silently dropped.
    """
    frame_id = frame.id
    if frame_id in self._seen_ids:
        return
    self._seen_ids.add(frame_id)
    await self.push_frame(frame, direction)

async def _process_up_queue(self):
running = True
while running:
frame = await self._up_queue.get()
if frame:
while True:
try:
frame = await self._up_queue.get()
await self._parallel_push_frame(frame, FrameDirection.UPSTREAM)
running = frame is not None
self._up_queue.task_done()
self._up_queue.task_done()
except asyncio.CancelledError:
break

async def _process_down_queue(self):
running = True
while running:
frame = await self._down_queue.get()
if frame:
await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM)
running = frame is not None
self._down_queue.task_done()
try:
frame = await self._down_queue.get()

endframe_counter = self._endframe_counter.get(frame.id, 0)

# If we have a counter, decrement it.
if endframe_counter > 0:
self._endframe_counter[frame.id] -= 1
endframe_counter = self._endframe_counter[frame.id]

# If we don't have a counter or we reached 0, push the frame.
if endframe_counter == 0:
await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM)

running = not (endframe_counter == 0 and isinstance(frame, EndFrame))

self._down_queue.task_done()
except asyncio.CancelledError:
break

0 comments on commit 2c46a07

Please sign in to comment.