diff --git a/src/pipecat/pipeline/parallel_pipeline.py b/src/pipecat/pipeline/parallel_pipeline.py index 323f7ed24..b39258782 100644 --- a/src/pipecat/pipeline/parallel_pipeline.py +++ b/src/pipecat/pipeline/parallel_pipeline.py @@ -6,11 +6,18 @@ import asyncio from itertools import chain -from typing import Awaitable, Callable, List +from typing import Awaitable, Callable, Dict, List from loguru import logger -from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame, SystemFrame +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + Frame, + StartFrame, + StartInterruptionFrame, + SystemFrame, +) from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.frame_processor import FrameDirection, FrameProcessor @@ -72,11 +79,10 @@ def __init__(self, *args): self._sources = [] self._sinks = [] self._seen_ids = set() + self._endframe_counter: Dict[int, int] = {} self._up_queue = asyncio.Queue() self._down_queue = asyncio.Queue() - self._up_task: asyncio.Task | None = None - self._down_task: asyncio.Task | None = None self._pipelines = [] @@ -111,18 +117,19 @@ def processors_with_metrics(self) -> List[FrameProcessor]: # async def cleanup(self): + await asyncio.gather(*[s.cleanup() for s in self._sources]) await asyncio.gather(*[p.cleanup() for p in self._pipelines]) - - async def _start_tasks(self): - loop = self.get_event_loop() - self._up_task = loop.create_task(self._process_up_queue()) - self._down_task = loop.create_task(self._process_down_queue()) + await asyncio.gather(*[s.cleanup() for s in self._sinks]) async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) if isinstance(frame, StartFrame): - await self._start_tasks() + await self._start() + elif isinstance(frame, EndFrame): + self._endframe_counter[frame.id] = len(self._pipelines) + elif isinstance(frame, CancelFrame): + await self._cancel() if direction == FrameDirection.UPSTREAM: # If we get an upstream frame we process it in each sink. @@ -131,16 +138,45 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # If we get a downstream frame we process it in each source. await asyncio.gather(*[s.queue_frame(frame, direction) for s in self._sources]) - # If we get an EndFrame we stop our queue processing tasks and wait on - # all the pipelines to finish. - if isinstance(frame, (CancelFrame, EndFrame)): - # Use None to indicate when queues should be done processing. - await self._up_queue.put(None) - await self._down_queue.put(None) - if self._up_task: - await self._up_task - if self._down_task: - await self._down_task + # Handle interruptions after everything has been cancelled. + if isinstance(frame, StartInterruptionFrame): + await self._handle_interruption() + # Wait for tasks to finish. + elif isinstance(frame, EndFrame): + await self._stop() + + async def _start(self): + await self._create_tasks() + + async def _stop(self): + # The up task doesn't receive an EndFrame, so we just cancel it. + self._up_task.cancel() + await self._up_task + # The down tasks waits for the last EndFrame send by the internal + # pipelines. + await self._down_task + + async def _cancel(self): + self._up_task.cancel() + await self._up_task + self._down_task.cancel() + await self._down_task + + async def _create_tasks(self): + loop = self.get_event_loop() + self._up_task = loop.create_task(self._process_up_queue()) + self._down_task = loop.create_task(self._process_down_queue()) + + async def _drain_queues(self): + while not self._up_queue.empty: + await self._up_queue.get() + while not self._down_queue.empty: + await self._down_queue.get() + + async def _handle_interruption(self): + await self._cancel() + await self._drain_queues() + await self._create_tasks() async def _parallel_push_frame(self, frame: Frame, direction: FrameDirection): if frame.id not in self._seen_ids: @@ -148,19 +184,33 @@ async def _parallel_push_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) async def _process_up_queue(self): - running = True - while running: - frame = await self._up_queue.get() - if frame: + while True: + try: + frame = await self._up_queue.get() await self._parallel_push_frame(frame, FrameDirection.UPSTREAM) - running = frame is not None - self._up_queue.task_done() + self._up_queue.task_done() + except asyncio.CancelledError: + break async def _process_down_queue(self): running = True while running: - frame = await self._down_queue.get() - if frame: - await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM) - running = frame is not None - self._down_queue.task_done() + try: + frame = await self._down_queue.get() + + endframe_counter = self._endframe_counter.get(frame.id, 0) + + # If we have a counter, decrement it. + if endframe_counter > 0: + self._endframe_counter[frame.id] -= 1 + endframe_counter = self._endframe_counter[frame.id] + + # If we don't have a counter or we reached 0, push the frame. + if endframe_counter == 0: + await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM) + + running = not (endframe_counter == 0 and isinstance(frame, EndFrame)) + + self._down_queue.task_done() + except asyncio.CancelledError: + break