diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index f79ff6f39..96845430d 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -69,6 +69,19 @@ async def _handle_upstream_frame(self, frame: Frame): await self._up_queue.put(StopTaskFrame()) +class Sink(FrameProcessor): + def __init__(self, down_queue: asyncio.Queue): + super().__init__() + self._down_queue = down_queue + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + # We really just want to know when the EndFrame reached the sink. + if isinstance(frame, EndFrame): + await self._down_queue.put(frame) + + class PipelineTask: def __init__( self, @@ -84,12 +97,16 @@ def __init__( self._params = params self._finished = False - self._down_queue = asyncio.Queue() self._up_queue = asyncio.Queue() + self._down_queue = asyncio.Queue() + self._push_queue = asyncio.Queue() self._source = Source(self._up_queue) self._source.link(pipeline) + self._sink = Sink(self._down_queue) + pipeline.link(self._sink) + def has_finished(self): return self._finished @@ -103,19 +120,19 @@ async def cancel(self): # out-of-band from the main streaming task which is what we want since # we want to cancel right away. await self._source.push_frame(CancelFrame()) - self._process_down_task.cancel() + self._process_push_task.cancel() self._process_up_task.cancel() - await self._process_down_task + await self._process_push_task await self._process_up_task async def run(self): self._process_up_task = asyncio.create_task(self._process_up_queue()) - self._process_down_task = asyncio.create_task(self._process_down_queue()) - await asyncio.gather(self._process_up_task, self._process_down_task) + self._process_push_task = asyncio.create_task(self._process_push_queue()) + await asyncio.gather(self._process_up_task, self._process_push_task) self._finished = True async def queue_frame(self, frame: Frame): - await self._down_queue.put(frame) + await self._push_queue.put(frame) async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): if isinstance(frames, AsyncIterable): @@ -133,7 +150,7 @@ def _initial_metrics_frame(self) -> MetricsFrame: data.append(ProcessingMetricsData(processor=p.name, value=0.0)) return MetricsFrame(data=data) - async def _process_down_queue(self): + async def _process_push_queue(self): self._clock.start() start_frame = StartFrame( @@ -154,11 +171,13 @@ async def _process_down_queue(self): should_cleanup = True while running: try: - frame = await self._down_queue.get() + frame = await self._push_queue.get() await self._source.process_frame(frame, FrameDirection.DOWNSTREAM) + if isinstance(frame, EndFrame): + await self._wait_for_endframe() running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame)) should_cleanup = not isinstance(frame, StopTaskFrame) - self._down_queue.task_done() + self._push_queue.task_done() except asyncio.CancelledError: break # Cleanup only if we need to. @@ -169,6 +188,12 @@ async def _process_down_queue(self): self._process_up_task.cancel() await self._process_up_task + async def _wait_for_endframe(self): + # NOTE(aleix): the Sink element just pushes EndFrames to the down queue, + # so just wait for it. In the future we might do something else here, + # but for now this is fine. + await self._down_queue.get() + async def _process_up_queue(self): while True: try: