Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pipeline: add UserTranscriptionAggregator #111

Merged
merged 2 commits into from
Apr 9, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions src/dailyai/pipeline/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,79 @@
from typing import AsyncGenerator, Coroutine, List


class BasicResponseAggregator(FrameProcessor):
"""This frame processor aggregates frames between a start and an end frame
into complete text frame sentences.

For example, frame input/output:
UserStartedSpeakingFrame() -> None
TranscriptionFrame("Hello,") -> None
TranscriptionFrame(" world.") -> None
UserStoppedSpeakingFrame() -> TextFrame("Hello world.")

Doctest:
>>> async def print_frames(aggregator, frame):
... async for frame in aggregator.process_frame(frame):
... if isinstance(frame, TextFrame):
... print(frame.text)

>>> aggregator = BasicResponseAggregator(start_frame = UserStartedSpeakingFrame,
... end_frame=UserStoppedSpeakingFrame,
... accumulator_frame=TranscriptionFrame,
... pass_through=False)
>>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
>>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
>>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
>>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame()))
Hello, world.

"""

def __init__(
self,
*,
start_frame,
end_frame,
accumulator_frame,
pass_through=True,
):
self.aggregation = ""
self.aggregating = False
self._start_frame = start_frame
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._pass_through = pass_through

async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, self._start_frame):
self.aggregating = True
elif isinstance(frame, self._end_frame):
self.aggregating = False
# Sometimes VAD triggers quickly on and off. If we don't get any transcription,
# it creates empty LLM message queue frames
if len(self.aggregation) > 0:
output = self.aggregation
self.aggregation = ""
yield self._end_frame()
yield TextFrame(output.strip())
elif isinstance(frame, self._accumulator_frame) and self.aggregating:
self.aggregation += f" {frame.text}"
if self._pass_through:
yield frame
else:
yield frame


class UserTranscriptionAggregator(BasicResponseAggregator):
def __init__(self):
super().__init__(
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
pass_through=False,
)


class ResponseAggregator(FrameProcessor):

def __init__(
Expand Down
Loading