diff --git a/src/pipecat/pipeline/base_pipeline.py b/src/pipecat/pipeline/base_pipeline.py new file mode 100644 index 000000000..83dbc2730 --- /dev/null +++ b/src/pipecat/pipeline/base_pipeline.py @@ -0,0 +1,21 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from abc import abstractmethod + +from typing import List + +from pipecat.processors.frame_processor import FrameProcessor + + +class BasePipeline(FrameProcessor): + + def __init__(self): + super().__init__() + + @abstractmethod + def services(self) -> List[FrameProcessor]: + pass diff --git a/src/pipecat/pipeline/parallel_pipeline.py b/src/pipecat/pipeline/parallel_pipeline.py index ccf72bd90..f89e61ca7 100644 --- a/src/pipecat/pipeline/parallel_pipeline.py +++ b/src/pipecat/pipeline/parallel_pipeline.py @@ -6,6 +6,10 @@ import asyncio +from itertools import chain +from typing import List + +from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.frames.frames import CancelFrame, EndFrame, Frame, StartFrame @@ -45,7 +49,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self._down_queue.put(frame) -class ParallelPipeline(FrameProcessor): +class ParallelPipeline(BasePipeline): def __init__(self, *args): super().__init__() @@ -81,6 +85,13 @@ def __init__(self, *args): logger.debug(f"Finished creating {self} pipelines") + # + # BasePipeline + # + + def services(self) -> List[FrameProcessor]: + return list(chain.from_iterable(p.services() for p in self._pipelines)) + # # Frame processor # diff --git a/src/pipecat/pipeline/parallel_task.py b/src/pipecat/pipeline/parallel_task.py index ce341bde5..b1f3c520c 100644 --- a/src/pipecat/pipeline/parallel_task.py +++ b/src/pipecat/pipeline/parallel_task.py @@ -6,8 +6,10 @@ import asyncio +from itertools import chain from typing import List +from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.frame_processor import FrameDirection, FrameProcessor from pipecat.frames.frames import Frame @@ -47,7 +49,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self._down_queue.put(frame) -class ParallelTask(FrameProcessor): +class ParallelTask(BasePipeline): def __init__(self, *args): super().__init__() @@ -79,6 +81,13 @@ def __init__(self, *args): self._pipelines.append(pipeline) logger.debug(f"Finished creating {self} pipelines") + # + # BasePipeline + # + + def services(self) -> List[FrameProcessor]: + return list(chain.from_iterable(p.services() for p in self._pipelines)) + # # Frame processor # diff --git a/src/pipecat/pipeline/pipeline.py b/src/pipecat/pipeline/pipeline.py index 2cb5b45d4..1d7e8a024 100644 --- a/src/pipecat/pipeline/pipeline.py +++ b/src/pipecat/pipeline/pipeline.py @@ -4,12 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio - from typing import Callable, Coroutine, List -from pipecat.frames.frames import Frame +from pipecat.frames.frames import Frame, MetricsFrame, StartFrame +from pipecat.pipeline.base_pipeline import BasePipeline from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.services.ai_services import AIService class PipelineSource(FrameProcessor): @@ -44,7 +44,7 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self._downstream_push_frame(frame, direction) -class Pipeline(FrameProcessor): +class Pipeline(BasePipeline): def __init__(self, processors: List[FrameProcessor]): super().__init__() @@ -57,6 +57,19 @@ def __init__(self, processors: List[FrameProcessor]): self._link_processors() + # + # BasePipeline + # + + def services(self): + services = [] + for p in self._processors: + if isinstance(p, AIService): + services.append(p) + elif isinstance(p, Pipeline): + services += p.services() + return services + # # Frame processor # @@ -67,6 +80,9 @@ async def cleanup(self): async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) + if isinstance(frame, StartFrame) and self.metrics_enabled: + await self._send_initial_metrics() + if direction == FrameDirection.DOWNSTREAM: await self._source.process_frame(frame, FrameDirection.DOWNSTREAM) elif direction == FrameDirection.UPSTREAM: @@ -81,3 +97,9 @@ def _link_processors(self): for curr in self._processors[1:]: prev.link(curr) prev = curr + + async def _send_initial_metrics(self): + services = self.services() + ttfb = dict(zip([s.name for s in services], [0] * len(services))) + frame = MetricsFrame(ttfb=ttfb) + await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)