Skip to content

Commit

Permalink
pipeline: add UserTranscriptionAggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
aconchillo committed Apr 9, 2024
1 parent 53930b4 commit d7b2e67
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions src/dailyai/pipeline/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,79 @@
from typing import AsyncGenerator, Coroutine, List


class BasicResponseAggregator(FrameProcessor):
"""This frame processor aggregates frames between a start and an end frame
into complete text frame sentences.
For example, frame input/output:
UserStartedSpeakingFrame() -> None
TranscriptionFrame("Hello,") -> None
TranscriptionFrame(" world.") -> None
UserStoppedSpeakingFrame() -> TextFrame("Hello world.")
Doctest:
>>> async def print_frames(aggregator, frame):
... async for frame in aggregator.process_frame(frame):
... if isinstance(frame, TextFrame):
... print(frame.text)
>>> aggregator = BasicResponseAggregator(start_frame = UserStartedSpeakingFrame,
... end_frame=UserStoppedSpeakingFrame,
... accumulator_frame=TranscriptionFrame,
... pass_through=False)
>>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
>>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
>>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
>>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame()))
Hello, world.
"""

def __init__(
self,
*,
start_frame,
end_frame,
accumulator_frame,
pass_through=True,
):
self.aggregation = ""
self.aggregating = False
self._start_frame = start_frame
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._pass_through = pass_through

async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, self._start_frame):
self.aggregating = True
elif isinstance(frame, self._end_frame):
self.aggregating = False
# Sometimes VAD triggers quickly on and off. If we don't get any transcription,
# it creates empty LLM message queue frames
if len(self.aggregation) > 0:
output = self.aggregation
self.aggregation = ""
yield self._end_frame()
yield TextFrame(output.strip())
elif isinstance(frame, self._accumulator_frame) and self.aggregating:
self.aggregation += f" {frame.text}"
if self._pass_through:
yield frame
else:
yield frame


class UserTranscriptionAggregator(BasicResponseAggregator):
def __init__(self):
super().__init__(
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
pass_through=False,
)


class ResponseAggregator(FrameProcessor):

def __init__(
Expand Down

0 comments on commit d7b2e67

Please sign in to comment.