Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pipeline(parallel): wait for slowest endframe #885

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed an issue that would cause `ParallelPipeline` to handle `EndFrame`
incorrectly, causing the main pipeline to not terminate or to terminate too early.

- Fixed an audio stuttering issue in `FastPitchTTSService`.

- Fixed a `BaseOutputTransport` issue that was causing non-audio frames being
Expand Down
112 changes: 81 additions & 31 deletions src/pipecat/pipeline/parallel_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@

import asyncio
from itertools import chain
from typing import Awaitable, Callable, List
from typing import Awaitable, Callable, Dict, List

from loguru import logger

from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame, SystemFrame
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
StartFrame,
StartInterruptionFrame,
SystemFrame,
)
from pipecat.pipeline.base_pipeline import BasePipeline
from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
Expand Down Expand Up @@ -72,11 +79,10 @@ def __init__(self, *args):
self._sources = []
self._sinks = []
self._seen_ids = set()
self._endframe_counter: Dict[int, int] = {}

self._up_queue = asyncio.Queue()
self._down_queue = asyncio.Queue()
self._up_task: asyncio.Task | None = None
self._down_task: asyncio.Task | None = None

self._pipelines = []

Expand Down Expand Up @@ -111,18 +117,19 @@ def processors_with_metrics(self) -> List[FrameProcessor]:
#

async def cleanup(self):
await asyncio.gather(*[s.cleanup() for s in self._sources])
await asyncio.gather(*[p.cleanup() for p in self._pipelines])

async def _start_tasks(self):
loop = self.get_event_loop()
self._up_task = loop.create_task(self._process_up_queue())
self._down_task = loop.create_task(self._process_down_queue())
await asyncio.gather(*[s.cleanup() for s in self._sinks])

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, StartFrame):
await self._start_tasks()
await self._start()
elif isinstance(frame, EndFrame):
self._endframe_counter[frame.id] = len(self._pipelines)
elif isinstance(frame, CancelFrame):
await self._cancel()

if direction == FrameDirection.UPSTREAM:
# If we get an upstream frame we process it in each sink.
Expand All @@ -131,36 +138,79 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
# If we get a downstream frame we process it in each source.
await asyncio.gather(*[s.queue_frame(frame, direction) for s in self._sources])

# If we get an EndFrame we stop our queue processing tasks and wait on
# all the pipelines to finish.
if isinstance(frame, (CancelFrame, EndFrame)):
# Use None to indicate when queues should be done processing.
await self._up_queue.put(None)
await self._down_queue.put(None)
if self._up_task:
await self._up_task
if self._down_task:
await self._down_task
# Handle interruptions after everything has been cancelled.
if isinstance(frame, StartInterruptionFrame):
await self._handle_interruption()
# Wait for tasks to finish.
elif isinstance(frame, EndFrame):
await self._stop()

async def _start(self):
    """Start this pipeline by creating its queue-processing tasks."""
    await self._create_tasks()

async def _stop(self):
    """Shut down the queue-processing tasks after an EndFrame.

    The upstream task is cancelled and awaited; the downstream task is
    only awaited, so it is allowed to run to completion on its own.
    """
    # No EndFrame ever reaches the upstream task, so cancelling it is the
    # only way to make it finish.
    upstream_task = self._up_task
    upstream_task.cancel()
    await upstream_task

    # The downstream task waits for the last EndFrame sent by the internal
    # pipelines and then exits by itself.
    await self._down_task

async def _cancel(self):
    """Cancel both queue-processing tasks, upstream first, awaiting each one."""
    for task_attr in ("_up_task", "_down_task"):
        # Look the task up only when its turn comes, then cancel it and wait
        # for it to finish before moving on to the next one.
        task = getattr(self, task_attr)
        task.cancel()
        await task

async def _create_tasks(self):
    """Create the tasks that process the upstream and downstream queues.

    The new tasks are stored in ``self._up_task`` and ``self._down_task``.
    """
    event_loop = self.get_event_loop()

    up_coro = self._process_up_queue()
    self._up_task = event_loop.create_task(up_coro)

    down_coro = self._process_down_queue()
    self._down_task = event_loop.create_task(down_coro)

async def _drain_queues(self):
    """Discard every item still pending in the upstream and downstream queues.

    ``asyncio.Queue.empty`` is a method and has to be called. Testing the
    bound method object itself (``not queue.empty``) is always False, so
    the loops would never run and stale items would stay in the queues.
    """
    for queue in (self._up_queue, self._down_queue):
        while not queue.empty():
            # The queue was just checked to be non-empty and get_nowait()
            # never suspends, so this cannot block or raise QueueEmpty.
            queue.get_nowait()

async def _handle_interruption(self):
    """Handle an interruption: cancel the tasks, drain the queues, recreate the tasks."""
    # Each step is looked up and awaited in turn, so a step only begins
    # once the previous one has fully finished.
    for step_name in ("_cancel", "_drain_queues", "_create_tasks"):
        await getattr(self, step_name)()

async def _parallel_push_frame(self, frame: Frame, direction: FrameDirection):
    """Push ``frame`` in ``direction`` unless a frame with the same id was already pushed.

    Ids of pushed frames are recorded in ``self._seen_ids``; a frame whose
    id is already recorded is silently dropped.
    """
    frame_id = frame.id
    if frame_id in self._seen_ids:
        return

    self._seen_ids.add(frame_id)
    await self.push_frame(frame, direction)

async def _process_up_queue(self):
running = True
while running:
frame = await self._up_queue.get()
if frame:
while True:
try:
frame = await self._up_queue.get()
await self._parallel_push_frame(frame, FrameDirection.UPSTREAM)
running = frame is not None
self._up_queue.task_done()
self._up_queue.task_done()
except asyncio.CancelledError:
break

async def _process_down_queue(self):
running = True
while running:
frame = await self._down_queue.get()
if frame:
await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM)
running = frame is not None
self._down_queue.task_done()
try:
frame = await self._down_queue.get()

endframe_counter = self._endframe_counter.get(frame.id, 0)

# If we have a counter, decrement it.
if endframe_counter > 0:
self._endframe_counter[frame.id] -= 1
endframe_counter = self._endframe_counter[frame.id]

# If we don't have a counter or we reached 0, push the frame.
if endframe_counter == 0:
await self._parallel_push_frame(frame, FrameDirection.DOWNSTREAM)

running = not (endframe_counter == 0 and isinstance(frame, EndFrame))

self._down_queue.task_done()
except asyncio.CancelledError:
break
Loading