pipeline: add UserTranscriptionAggregator #111

Merged
merged 2 commits on Apr 9, 2024

examples/foundational/07-interruptible.py (4 additions, 4 deletions)
@@ -3,8 +3,8 @@
 import logging
 import os
 from dailyai.pipeline.aggregators import (
-    LLMResponseAggregator,
-    UserResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )

 from dailyai.pipeline.pipeline import Pipeline
@@ -63,8 +63,8 @@ async def run_conversation():

         await transport.run_interruptible_pipeline(
             pipeline,
-            post_processor=LLMResponseAggregator(messages),
-            pre_processor=UserResponseAggregator(messages),
+            post_processor=LLMAssistantResponseAggregator(messages),
+            pre_processor=LLMUserResponseAggregator(messages),
         )

     transport.transcription_settings["extra"]["punctuate"] = False

examples/starter-apps/chatbot.py (4 additions, 4 deletions)
@@ -6,8 +6,8 @@
 from typing import AsyncGenerator

 from dailyai.pipeline.aggregators import (
-    LLMResponseAggregator,
-    UserResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.frames import (
     ImageFrame,
@@ -135,8 +135,8 @@ async def run_conversation():

         await transport.run_interruptible_pipeline(
             pipeline,
-            post_processor=LLMResponseAggregator(messages),
-            pre_processor=UserResponseAggregator(messages),
+            post_processor=LLMAssistantResponseAggregator(messages),
+            pre_processor=LLMUserResponseAggregator(messages),
         )

     transport.transcription_settings["extra"]["endpointing"] = True

examples/starter-apps/storybot.py (4 additions, 4 deletions)
@@ -19,8 +19,8 @@
 from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
 from dailyai.pipeline.aggregators import (
     LLMAssistantContextAggregator,
-    UserResponseAggregator,
-    LLMResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.frames import (
     EndPipeFrame,
@@ -209,8 +209,8 @@ async def main(room_url: str, token):
             key_id=os.getenv("FAL_KEY_ID"),
             key_secret=os.getenv("FAL_KEY_SECRET"),
         )
-        lra = LLMResponseAggregator(messages)
-        ura = UserResponseAggregator(messages)
+        lra = LLMAssistantResponseAggregator(messages)
+        ura = LLMUserResponseAggregator(messages)
         sp = StoryProcessor(messages, story)
         sig = StoryImageGenerator(story, llm, img)

src/dailyai/pipeline/aggregators.py (75 additions, 2 deletions)
@@ -22,6 +22,79 @@


+class ResponseAggregator(FrameProcessor):
+    """This frame processor aggregates frames between a start and an end frame
+    into complete text frame sentences.
+
+    For example, frame input/output:
+        UserStartedSpeakingFrame() -> None
+        TranscriptionFrame("Hello,") -> None
+        TranscriptionFrame(" world.") -> None
+        UserStoppedSpeakingFrame() -> TextFrame("Hello world.")
+
+    Doctest:
+    >>> async def print_frames(aggregator, frame):
+    ...     async for frame in aggregator.process_frame(frame):
+    ...         if isinstance(frame, TextFrame):
+    ...             print(frame.text)
+
+    >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame,
+    ...                                 end_frame=UserStoppedSpeakingFrame,
+    ...                                 accumulator_frame=TranscriptionFrame,
+    ...                                 pass_through=False)
+    >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
+    >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
+    >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
+    >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame()))
+    Hello, world.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        start_frame,
+        end_frame,
+        accumulator_frame,
+        pass_through=True,
+    ):
+        self.aggregation = ""
+        self.aggregating = False
+        self._start_frame = start_frame
+        self._end_frame = end_frame
+        self._accumulator_frame = accumulator_frame
+        self._pass_through = pass_through
+
+    async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
+        if isinstance(frame, self._start_frame):
+            self.aggregating = True
+        elif isinstance(frame, self._end_frame):
+            self.aggregating = False
+            # Sometimes VAD triggers quickly on and off. If we don't get any transcription,
+            # it creates empty LLM message queue frames
+            if len(self.aggregation) > 0:
+                output = self.aggregation
+                self.aggregation = ""
+                yield self._end_frame()
+                yield TextFrame(output.strip())
+        elif isinstance(frame, self._accumulator_frame) and self.aggregating:
+            self.aggregation += f" {frame.text}"
+            if self._pass_through:
+                yield frame
+        else:
+            yield frame
+
+
+class UserResponseAggregator(ResponseAggregator):
+    def __init__(self):
+        super().__init__(
+            start_frame=UserStartedSpeakingFrame,
+            end_frame=UserStoppedSpeakingFrame,
+            accumulator_frame=TranscriptionFrame,
+            pass_through=False,
+        )
+
+
 class LLMResponseAggregator(FrameProcessor):

     def __init__(
         self,
@@ -66,7 +139,7 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             yield frame


-class LLMResponseAggregator(ResponseAggregator):
+class LLMAssistantResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
         super().__init__(
             messages=messages,
@@ -77,7 +150,7 @@ def __init__(self, messages: list[dict]):
         )


-class UserResponseAggregator(ResponseAggregator):
+class LLMUserResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
         super().__init__(
             messages=messages,
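
For reference, a minimal sketch of how the newly added UserResponseAggregator could be exercised on its own, mirroring the docstring example above. It assumes TextFrame, TranscriptionFrame, UserStartedSpeakingFrame, and UserStoppedSpeakingFrame are importable from dailyai.pipeline.frames and that TranscriptionFrame takes the same positional arguments as in the doctest; this snippet is illustrative only and is not part of this PR's diff.

import asyncio

from dailyai.pipeline.aggregators import UserResponseAggregator
from dailyai.pipeline.frames import (
    TextFrame,
    TranscriptionFrame,
    UserStartedSpeakingFrame,
    UserStoppedSpeakingFrame,
)


async def main():
    # UserResponseAggregator buffers TranscriptionFrames received between the
    # started/stopped speaking markers and emits one TextFrame with the full
    # sentence when the user stops speaking.
    aggregator = UserResponseAggregator()
    frames = [
        UserStartedSpeakingFrame(),
        TranscriptionFrame("Hello,", 1, 1),
        TranscriptionFrame("world.", 1, 2),
        UserStoppedSpeakingFrame(),
    ]
    for frame in frames:
        async for out in aggregator.process_frame(frame):
            if isinstance(out, TextFrame):
                print(out.text)  # prints: Hello, world.


asyncio.run(main())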