class BasicResponseAggregator(FrameProcessor):
    """Aggregate frames received between a start and an end frame into a
    single complete ``TextFrame``.

    Accumulator frames seen while aggregating have their ``text`` collected;
    when the end frame arrives, a fresh end frame is emitted followed by a
    ``TextFrame`` with the stripped, space-joined aggregation. If no
    accumulator frames arrived (e.g. VAD flickered on/off with no
    transcription), nothing is emitted, to avoid empty LLM message frames.

    For example, frame input/output:

        UserStartedSpeakingFrame() -> None
        TranscriptionFrame("Hello,") -> None
        TranscriptionFrame(" world.") -> None
        UserStoppedSpeakingFrame() -> TextFrame("Hello world.")

    Doctest:
    >>> async def print_frames(aggregator, frame):
    ...     async for frame in aggregator.process_frame(frame):
    ...         if isinstance(frame, TextFrame):
    ...             print(frame.text)

    >>> aggregator = BasicResponseAggregator(start_frame=UserStartedSpeakingFrame,
    ...                                      end_frame=UserStoppedSpeakingFrame,
    ...                                      accumulator_frame=TranscriptionFrame,
    ...                                      pass_through=False)
    >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
    >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
    >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
    >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame()))
    Hello, world.
    """

    def __init__(
        self,
        *,
        start_frame,
        end_frame,
        accumulator_frame,
        pass_through=True,
    ):
        """Create the aggregator.

        Args:
            start_frame: frame class that turns aggregation on.
            end_frame: frame class that turns aggregation off and flushes.
            accumulator_frame: frame class whose ``text`` is accumulated.
            pass_through: if True, accumulator frames are also re-emitted
                while aggregating; if False they are swallowed.
        """
        # Fix: chain to the FrameProcessor base class. The original skipped
        # this; a subclass should always call super().__init__() so any base
        # state is initialized. NOTE(review): assumes FrameProcessor.__init__
        # takes no required arguments — confirm against its definition.
        super().__init__()
        self.aggregation = ""
        self.aggregating = False
        self._start_frame = start_frame
        self._end_frame = end_frame
        self._accumulator_frame = accumulator_frame
        self._pass_through = pass_through

    async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
        """Process one frame, possibly yielding aggregated output frames."""
        if isinstance(frame, self._start_frame):
            self.aggregating = True
        elif isinstance(frame, self._end_frame):
            self.aggregating = False
            # Sometimes VAD triggers quickly on and off. If we don't get any
            # transcription, it creates empty LLM message queue frames — so
            # only flush (and re-emit an end frame) when something accumulated.
            if len(self.aggregation) > 0:
                output = self.aggregation
                self.aggregation = ""
                yield self._end_frame()
                yield TextFrame(output.strip())
        elif isinstance(frame, self._accumulator_frame) and self.aggregating:
            # Leading space keeps words separated; stripped on flush.
            self.aggregation += f" {frame.text}"
            if self._pass_through:
                yield frame
        else:
            # Anything unrelated passes through untouched.
            yield frame


class UserTranscriptionAggregator(BasicResponseAggregator):
    """Aggregate user TranscriptionFrames between VAD start/stop frames into
    one TextFrame, swallowing the individual transcription fragments."""

    def __init__(self):
        super().__init__(
            start_frame=UserStartedSpeakingFrame,
            end_frame=UserStoppedSpeakingFrame,
            accumulator_frame=TranscriptionFrame,
            pass_through=False,
        )