From d7b2e67c35baab21ded34db3c25b9cb330b7c920 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Mon, 8 Apr 2024 17:15:14 -0700
Subject: [PATCH 1/2] pipeline: add UserTranscriptionAggregator

---
 src/dailyai/pipeline/aggregators.py | 73 +++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)

diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py
index bbed2fbc3..6e8e82b9b 100644
--- a/src/dailyai/pipeline/aggregators.py
+++ b/src/dailyai/pipeline/aggregators.py
@@ -21,6 +21,79 @@
 from typing import AsyncGenerator, Coroutine, List
 
 
+class BasicResponseAggregator(FrameProcessor):
+    """This frame processor aggregates frames between a start and an end frame
+    into complete text frame sentences.
+
+    For example, frame input/output:
+        UserStartedSpeakingFrame() -> None
+        TranscriptionFrame("Hello,") -> None
+        TranscriptionFrame("world.") -> None
+        UserStoppedSpeakingFrame() -> TextFrame("Hello, world.")
+
+    Doctest:
+    >>> async def print_frames(aggregator, frame):
+    ...     async for frame in aggregator.process_frame(frame):
+    ...         if isinstance(frame, TextFrame):
+    ...             print(frame.text)
+
+    >>> aggregator = BasicResponseAggregator(start_frame = UserStartedSpeakingFrame,
+    ...                                      end_frame=UserStoppedSpeakingFrame,
+    ...                                      accumulator_frame=TranscriptionFrame,
+    ...                                      pass_through=False)
+    >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
+    >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
+    >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
+    >>> asyncio.run(print_frames(aggregator, UserStoppedSpeakingFrame()))
+    Hello, world.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        start_frame,
+        end_frame,
+        accumulator_frame,
+        pass_through=True,
+    ):
+        self.aggregation = ""
+        self.aggregating = False
+        self._start_frame = start_frame
+        self._end_frame = end_frame
+        self._accumulator_frame = accumulator_frame
+        self._pass_through = pass_through
+
+    async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
+        if isinstance(frame, self._start_frame):
+            self.aggregating = True
+        elif isinstance(frame, self._end_frame):
+            self.aggregating = False
+            # Sometimes VAD triggers quickly on and off. If we don't get any transcription,
+            # it creates empty LLM message queue frames.
+            if len(self.aggregation) > 0:
+                output = self.aggregation
+                self.aggregation = ""
+                yield self._end_frame()
+                yield TextFrame(output.strip())
+        elif isinstance(frame, self._accumulator_frame) and self.aggregating:
+            self.aggregation += f" {frame.text}"
+            if self._pass_through:
+                yield frame
+        else:
+            yield frame
+
+
+class UserTranscriptionAggregator(BasicResponseAggregator):
+    def __init__(self):
+        super().__init__(
+            start_frame=UserStartedSpeakingFrame,
+            end_frame=UserStoppedSpeakingFrame,
+            accumulator_frame=TranscriptionFrame,
+            pass_through=False,
+        )
+
+
 class ResponseAggregator(FrameProcessor):
 
     def __init__(
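
A minimal, self-contained sketch (not part of the patch) of what the new UserTranscriptionAggregator does, mirroring the doctest above; it assumes the frame classes are importable from dailyai.pipeline.frames:

    import asyncio

    from dailyai.pipeline.aggregators import UserTranscriptionAggregator
    from dailyai.pipeline.frames import (
        TextFrame,
        TranscriptionFrame,
        UserStartedSpeakingFrame,
        UserStoppedSpeakingFrame,
    )

    async def main():
        aggregator = UserTranscriptionAggregator()
        frames = [
            UserStartedSpeakingFrame(),
            TranscriptionFrame("Hello,", 1, 1),
            TranscriptionFrame("world.", 1, 2),
            UserStoppedSpeakingFrame(),
        ]
        # Transcriptions are accumulated (pass_through=False) and emitted as a
        # single TextFrame when the user stops speaking.
        for frame in frames:
            async for out in aggregator.process_frame(frame):
                if isinstance(out, TextFrame):
                    print(out.text)  # -> Hello, world.

    asyncio.run(main())
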
From 97b923e37eb64a92165f0d9240b566dde0a27740 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Tue, 9 Apr 2024 08:31:48 -0700
Subject: [PATCH 2/2] llm user and assistant aggregator renames

---
 examples/foundational/07-interruptible.py |  8 ++++----
 examples/starter-apps/chatbot.py          |  8 ++++----
 examples/starter-apps/storybot.py         |  8 ++++----
 src/dailyai/pipeline/aggregators.py       | 18 +++++++++---------
 4 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/examples/foundational/07-interruptible.py b/examples/foundational/07-interruptible.py
index 5da4fae16..cca92c906 100644
--- a/examples/foundational/07-interruptible.py
+++ b/examples/foundational/07-interruptible.py
@@ -3,8 +3,8 @@
 import logging
 import os
 from dailyai.pipeline.aggregators import (
-    LLMResponseAggregator,
-    UserResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 
 from dailyai.pipeline.pipeline import Pipeline
@@ -63,8 +63,8 @@ async def run_conversation():
 
     await transport.run_interruptible_pipeline(
         pipeline,
-        post_processor=LLMResponseAggregator(messages),
-        pre_processor=UserResponseAggregator(messages),
+        post_processor=LLMAssistantResponseAggregator(messages),
+        pre_processor=LLMUserResponseAggregator(messages),
     )
 
     transport.transcription_settings["extra"]["punctuate"] = False
diff --git a/examples/starter-apps/chatbot.py b/examples/starter-apps/chatbot.py
index 666c624c2..8f7e0237f 100644
--- a/examples/starter-apps/chatbot.py
+++ b/examples/starter-apps/chatbot.py
@@ -6,8 +6,8 @@
 from typing import AsyncGenerator
 
 from dailyai.pipeline.aggregators import (
-    LLMResponseAggregator,
-    UserResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.frames import (
     ImageFrame,
@@ -135,8 +135,8 @@ async def run_conversation():
 
     await transport.run_interruptible_pipeline(
         pipeline,
-        post_processor=LLMResponseAggregator(messages),
-        pre_processor=UserResponseAggregator(messages),
+        post_processor=LLMAssistantResponseAggregator(messages),
+        pre_processor=LLMUserResponseAggregator(messages),
     )
 
     transport.transcription_settings["extra"]["endpointing"] = True
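
Both example diffs above end up with the same wiring: the user-side aggregator runs as the pre-processor and the assistant-side aggregator as the post-processor of the interruptible pipeline. A condensed sketch of that call (assuming a configured transport, an assembled pipeline, and a shared messages list, as in the examples):

    from dailyai.pipeline.aggregators import (
        LLMAssistantResponseAggregator,
        LLMUserResponseAggregator,
    )

    async def run(transport, pipeline, messages: list[dict]):
        # Same call the examples make: both aggregators share `messages`, so user
        # turns and LLM replies accumulate in one conversation context.
        await transport.run_interruptible_pipeline(
            pipeline,
            post_processor=LLMAssistantResponseAggregator(messages),
            pre_processor=LLMUserResponseAggregator(messages),
        )
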
diff --git a/examples/starter-apps/storybot.py b/examples/starter-apps/storybot.py
index f7bb15971..f04435eca 100644
--- a/examples/starter-apps/storybot.py
+++ b/examples/starter-apps/storybot.py
@@ -19,8 +19,8 @@
 from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
 from dailyai.pipeline.aggregators import (
     LLMAssistantContextAggregator,
-    UserResponseAggregator,
-    LLMResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.frames import (
     EndPipeFrame,
@@ -209,8 +209,8 @@ async def main(room_url: str, token):
         key_id=os.getenv("FAL_KEY_ID"),
         key_secret=os.getenv("FAL_KEY_SECRET"),
     )
-    lra = LLMResponseAggregator(messages)
-    ura = UserResponseAggregator(messages)
+    lra = LLMAssistantResponseAggregator(messages)
+    ura = LLMUserResponseAggregator(messages)
     sp = StoryProcessor(messages, story)
     sig = StoryImageGenerator(story, llm, img)
 
diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py
index 6e8e82b9b..543db599f 100644
--- a/src/dailyai/pipeline/aggregators.py
+++ b/src/dailyai/pipeline/aggregators.py
@@ -21,7 +21,7 @@
 from typing import AsyncGenerator, Coroutine, List
 
 
-class BasicResponseAggregator(FrameProcessor):
+class ResponseAggregator(FrameProcessor):
     """This frame processor aggregates frames between a start and an end frame
     into complete text frame sentences.
 
@@ -37,10 +37,10 @@ class BasicResponseAggregator(FrameProcessor):
     ...         if isinstance(frame, TextFrame):
     ...             print(frame.text)
 
-    >>> aggregator = BasicResponseAggregator(start_frame = UserStartedSpeakingFrame,
-    ...                                      end_frame=UserStoppedSpeakingFrame,
-    ...                                      accumulator_frame=TranscriptionFrame,
-    ...                                      pass_through=False)
+    >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame,
+    ...                                 end_frame=UserStoppedSpeakingFrame,
+    ...                                 accumulator_frame=TranscriptionFrame,
+    ...                                 pass_through=False)
     >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
     >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
     >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
@@ -84,7 +84,7 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
         yield frame
 
 
-class UserTranscriptionAggregator(BasicResponseAggregator):
+class UserResponseAggregator(ResponseAggregator):
     def __init__(self):
         super().__init__(
             start_frame=UserStartedSpeakingFrame,
@@ -94,7 +94,7 @@ def __init__(self):
         )
 
 
-class ResponseAggregator(FrameProcessor):
+class LLMResponseAggregator(FrameProcessor):
 
     def __init__(
         self,
@@ -139,7 +139,7 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
         yield frame
 
 
-class LLMResponseAggregator(ResponseAggregator):
+class LLMAssistantResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
         super().__init__(
            messages=messages,
@@ -150,7 +150,7 @@ def __init__(self, messages: list[dict]):
         )
 
 
-class UserResponseAggregator(ResponseAggregator):
+class LLMUserResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
         super().__init__(
             messages=messages,
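
Taken together, the two patches leave the aggregator family named as follows. The sketch below is not part of the patches; the old-to-new mapping is read directly off the diffs, while the role notes are inferred from the class definitions and the example wiring above:

    # Name after PATCH 1/2            -> Name after PATCH 2/2
    #   BasicResponseAggregator       -> ResponseAggregator              (generic start/end/accumulator -> TextFrame)
    #   UserTranscriptionAggregator   -> UserResponseAggregator          (user-speech preset of ResponseAggregator)
    #   ResponseAggregator            -> LLMResponseAggregator           (base for the `messages`-list aggregators)
    #   LLMResponseAggregator         -> LLMAssistantResponseAggregator  (post_processor in the examples)
    #   UserResponseAggregator        -> LLMUserResponseAggregator       (pre_processor in the examples)
    from dailyai.pipeline.aggregators import (
        LLMAssistantResponseAggregator,
        LLMUserResponseAggregator,
        UserResponseAggregator,
    )

    messages: list[dict] = []
    pre = LLMUserResponseAggregator(messages)        # aggregates user turns into `messages`
    post = LLMAssistantResponseAggregator(messages)  # aggregates LLM replies into `messages`
    text_only = UserResponseAggregator()             # yields one TextFrame per user turn, no `messages` list
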