diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py
index 5da4fae16..cca92c906 100644
--- a/examples/foundational/07-interruptible.py
+++ b/examples/foundational/07-interruptible.py
@@ -3,8 +3,8 @@
 import logging
 import os
 from dailyai.pipeline.aggregators import (
-    LLMResponseAggregator,
-    UserResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.pipeline import Pipeline
 
@@ -63,8 +63,8 @@ async def run_conversation():
 
         await transport.run_interruptible_pipeline(
             pipeline,
-            post_processor=LLMResponseAggregator(messages),
-            pre_processor=UserResponseAggregator(messages),
+            post_processor=LLMAssistantResponseAggregator(messages),
+            pre_processor=LLMUserResponseAggregator(messages),
         )
 
     transport.transcription_settings["extra"]["punctuate"] = False
diff --git a/examples/starter-apps/chatbot.py b/examples/starter-apps/chatbot.py
index 666c624c2..8f7e0237f 100644
--- a/examples/starter-apps/chatbot.py
+++ b/examples/starter-apps/chatbot.py
@@ -6,8 +6,8 @@
 from typing import AsyncGenerator
 
 from dailyai.pipeline.aggregators import (
-    LLMResponseAggregator,
-    UserResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.frames import (
     ImageFrame,
@@ -135,8 +135,8 @@ async def run_conversation():
 
         await transport.run_interruptible_pipeline(
             pipeline,
-            post_processor=LLMResponseAggregator(messages),
-            pre_processor=UserResponseAggregator(messages),
+            post_processor=LLMAssistantResponseAggregator(messages),
+            pre_processor=LLMUserResponseAggregator(messages),
         )
 
     transport.transcription_settings["extra"]["endpointing"] = True
diff --git a/examples/starter-apps/storybot.py b/examples/starter-apps/storybot.py
index f7bb15971..f04435eca 100644
--- a/examples/starter-apps/storybot.py
+++ b/examples/starter-apps/storybot.py
@@ -19,8 +19,8 @@
 from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
 from dailyai.pipeline.aggregators import (
     LLMAssistantContextAggregator,
-    UserResponseAggregator,
-    LLMResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.frames import (
     EndPipeFrame,
@@ -209,8 +209,8 @@ async def main(room_url: str, token):
         key_id=os.getenv("FAL_KEY_ID"),
         key_secret=os.getenv("FAL_KEY_SECRET"),
     )
-    lra = LLMResponseAggregator(messages)
-    ura = UserResponseAggregator(messages)
+    lra = LLMAssistantResponseAggregator(messages)
+    ura = LLMUserResponseAggregator(messages)
     sp = StoryProcessor(messages, story)
     sig = StoryImageGenerator(story, llm, img)
 
diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py
index bbed2fbc3..543db599f 100644
--- a/src/dailyai/pipeline/aggregators.py
+++ b/src/dailyai/pipeline/aggregators.py
@@ -22,6 +22,79 @@
 
 
 class ResponseAggregator(FrameProcessor):
+    """This frame processor aggregates frames between a start and an end frame
+    into complete text frame sentences.
+
+    For example, frame input/output:
+    UserStartedSpeakingFrame() -> None
+    TranscriptionFrame("Hello,") -> None
+    TranscriptionFrame(" world.") -> None
+    UserStoppedSpeakingFrame() -> TextFrame("Hello world.")
+
+    Doctest:
+    >>> async def print_frames(aggregator, frame):
+    ...     async for frame in aggregator.process_frame(frame):
+    ...         if isinstance(frame, TextFrame):
+    ...             print(frame.text)
+
+    >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame,
+    ...     end_frame=UserStoppedSpeakingFrame,
+    ...     accumulator_frame=TranscriptionFrame,
+    ...     pass_through=False)
+    >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
+    >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
+    >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
+    >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame()))
+    Hello, world.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        start_frame,
+        end_frame,
+        accumulator_frame,
+        pass_through=True,
+    ):
+        self.aggregation = ""
+        self.aggregating = False
+        self._start_frame = start_frame
+        self._end_frame = end_frame
+        self._accumulator_frame = accumulator_frame
+        self._pass_through = pass_through
+
+    async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
+        if isinstance(frame, self._start_frame):
+            self.aggregating = True
+        elif isinstance(frame, self._end_frame):
+            self.aggregating = False
+            # Sometimes VAD triggers quickly on and off. If we don't get any transcription,
+            # it creates empty LLM message queue frames
+            if len(self.aggregation) > 0:
+                output = self.aggregation
+                self.aggregation = ""
+                yield self._end_frame()
+                yield TextFrame(output.strip())
+        elif isinstance(frame, self._accumulator_frame) and self.aggregating:
+            self.aggregation += f" {frame.text}"
+            if self._pass_through:
+                yield frame
+        else:
+            yield frame
+
+
+class UserResponseAggregator(ResponseAggregator):
+    def __init__(self):
+        super().__init__(
+            start_frame=UserStartedSpeakingFrame,
+            end_frame=UserStoppedSpeakingFrame,
+            accumulator_frame=TranscriptionFrame,
+            pass_through=False,
+        )
+
+
+class LLMResponseAggregator(FrameProcessor):
 
     def __init__(
         self,
@@ -66,7 +139,7 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             yield frame
 
 
-class LLMResponseAggregator(ResponseAggregator):
+class LLMAssistantResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
         super().__init__(
             messages=messages,
@@ -77,7 +150,7 @@ def __init__(self, messages: list[dict]):
         )
 
 
-class UserResponseAggregator(ResponseAggregator):
+class LLMUserResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
         super().__init__(
             messages=messages,