Skip to content

Commit

Permalink
pipeline(task): since everything is async tasks should wait for EndFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
aconchillo committed Sep 30, 2024
1 parent e115a27 commit f64902e
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions src/pipecat/pipeline/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ async def _handle_upstream_frame(self, frame: Frame):
await self._up_queue.put(StopTaskFrame())


class Sink(FrameProcessor):
def __init__(self, down_queue: asyncio.Queue):
super().__init__()
self._down_queue = down_queue

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

# We really just want to know when the EndFrame reached the sink.
if isinstance(frame, EndFrame):
await self._down_queue.put(frame)


class PipelineTask:
def __init__(
self,
Expand All @@ -84,12 +97,16 @@ def __init__(
self._params = params
self._finished = False

self._down_queue = asyncio.Queue()
self._up_queue = asyncio.Queue()
self._down_queue = asyncio.Queue()
self._push_queue = asyncio.Queue()

self._source = Source(self._up_queue)
self._source.link(pipeline)

self._sink = Sink(self._down_queue)
pipeline.link(self._sink)

def has_finished(self):
return self._finished

Expand All @@ -103,19 +120,19 @@ async def cancel(self):
# out-of-band from the main streaming task which is what we want since
# we want to cancel right away.
await self._source.push_frame(CancelFrame())
self._process_down_task.cancel()
self._process_push_task.cancel()
self._process_up_task.cancel()
await self._process_down_task
await self._process_push_task
await self._process_up_task

async def run(self):
self._process_up_task = asyncio.create_task(self._process_up_queue())
self._process_down_task = asyncio.create_task(self._process_down_queue())
await asyncio.gather(self._process_up_task, self._process_down_task)
self._process_push_task = asyncio.create_task(self._process_push_queue())
await asyncio.gather(self._process_up_task, self._process_push_task)
self._finished = True

async def queue_frame(self, frame: Frame):
await self._down_queue.put(frame)
await self._push_queue.put(frame)

async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
if isinstance(frames, AsyncIterable):
Expand All @@ -133,7 +150,7 @@ def _initial_metrics_frame(self) -> MetricsFrame:
data.append(ProcessingMetricsData(processor=p.name, value=0.0))
return MetricsFrame(data=data)

async def _process_down_queue(self):
async def _process_push_queue(self):
self._clock.start()

start_frame = StartFrame(
Expand All @@ -154,11 +171,13 @@ async def _process_down_queue(self):
should_cleanup = True
while running:
try:
frame = await self._down_queue.get()
frame = await self._push_queue.get()
await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)
if isinstance(frame, EndFrame):
await self._wait_for_endframe()
running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame))
should_cleanup = not isinstance(frame, StopTaskFrame)
self._down_queue.task_done()
self._push_queue.task_done()
except asyncio.CancelledError:
break
# Cleanup only if we need to.
Expand All @@ -169,6 +188,12 @@ async def _process_down_queue(self):
self._process_up_task.cancel()
await self._process_up_task

async def _wait_for_endframe(self):
# NOTE(aleix): the Sink element just pushes EndFrames to the down queue,
# so just wait for it. In the future we might do something else here,
# but for now this is fine.
await self._down_queue.get()

async def _process_up_queue(self):
while True:
try:
Expand Down

0 comments on commit f64902e

Please sign in to comment.