From eb3bf117b111924d12bcb23e3066dd26c9719e72 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?=
Date: Thu, 11 Apr 2024 11:21:42 -0700
Subject: [PATCH] use InterimTranscriptionFrame in LLMUserResponseAggregator

---
 src/dailyai/pipeline/aggregators.py       | 77 ++++++++++++++++-------
 src/dailyai/pipeline/frames.py            | 11 ++++
 src/dailyai/transports/daily_transport.py |  9 ++-
 3 files changed, 74 insertions(+), 23 deletions(-)

diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py
index 79e02a69e..81ea5815c 100644
--- a/src/dailyai/pipeline/aggregators.py
+++ b/src/dailyai/pipeline/aggregators.py
@@ -9,6 +9,7 @@
     EndPipeFrame,
     Frame,
     ImageFrame,
+    InterimTranscriptionFrame,
     LLMMessagesFrame,
     LLMResponseEndFrame,
     LLMResponseStartFrame,
@@ -107,8 +108,8 @@ def __init__(
         start_frame,
         end_frame,
         accumulator_frame,
+        interim_accumulator_frame=None,
         pass_through=True,
-        end_frame_threshold=0.75,
     ):
         self.aggregation = ""
         self.aggregating = False
@@ -117,42 +118,75 @@ def __init__(
         self._start_frame = start_frame
         self._end_frame = end_frame
         self._accumulator_frame = accumulator_frame
+        self._interim_accumulator_frame = interim_accumulator_frame
         self._pass_through = pass_through
-        self._end_frame_threshold = end_frame_threshold
-        self._last_end_frame_time = 0
-
+        self._seen_start_frame = False
+        self._seen_end_frame = False
+        self._seen_interim_results = False
+
+    # Use cases implemented:
+    #
+    # S: Start, E: End, T: Transcription, I: Interim, X: Text
+    #
+    #        S E -> None
+    #      S T E -> X
+    #    S I T E -> X
+    #    S I E T -> X
+    #  S I E I T -> X
+    #
+    # The following case would not be supported:
+    #
+    #  S I E T1 I T2 -> X
+    #
+    # and T2 would be dropped.
     async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
         if not self.messages:
             return
 
+        send_aggregation = False
+
         if isinstance(frame, self._start_frame):
+            self._seen_start_frame = True
             self.aggregating = True
         elif isinstance(frame, self._end_frame):
-            self.aggregating = False
-            # Sometimes VAD triggers quickly on and off. If we don't get any transcription,
-            # it creates empty LLM message queue frames
-            if len(self.aggregation) > 0:
-                self.messages.append({"role": self._role, "content": self.aggregation})
-                self.aggregation = ""
-                yield self._end_frame()
-                yield LLMMessagesFrame(self.messages)
-            self._last_end_frame_time = time.time()
+            self._seen_end_frame = True
+
+            # We might have received the end frame but we might still be
+            # aggregating (i.e. we have seen interim results but not the final
+            # text).
+            self.aggregating = self._seen_interim_results
+
+            # Send the aggregation if we are not aggregating anymore (i.e. no
+            # more interim results received).
+            send_aggregation = not self.aggregating
         elif isinstance(frame, self._accumulator_frame):
-            # Also accept transcription frames received for a short period after
-            # the last end frame was received. It might be that transcription
-            # frames are a bit delayed.
-            diff_time = time.time() - self._last_end_frame_time
             if self.aggregating:
                 self.aggregation += f" {frame.text}"
-            elif diff_time <= self._end_frame_threshold:
-                self.messages.append({"role": self._role, "content": frame.text})
-                yield self._end_frame()
-                yield LLMMessagesFrame(self.messages)
+                # We have received a complete sentence, so if we have seen the
+                # end frame and we were still aggregating, it means we should
+                # send the aggregation.
+                send_aggregation = self._seen_end_frame
+
             if self._pass_through:
                 yield frame
+
+            # We just got our final result, so let's reset interim results.
+            self._seen_interim_results = False
+        elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
+            self._seen_interim_results = True
         else:
             yield frame
 
+        if send_aggregation and len(self.aggregation) > 0:
+            self.messages.append({"role": self._role, "content": self.aggregation})
+            yield self._end_frame()
+            yield LLMMessagesFrame(self.messages)
+            # Reset
+            self.aggregation = ""
+            self._seen_start_frame = False
+            self._seen_end_frame = False
+            self._seen_interim_results = False
+
 
 class LLMAssistantResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
@@ -173,6 +207,7 @@ def __init__(self, messages: list[dict]):
             start_frame=UserStartedSpeakingFrame,
             end_frame=UserStoppedSpeakingFrame,
             accumulator_frame=TranscriptionFrame,
+            interim_accumulator_frame=InterimTranscriptionFrame,
             pass_through=False,
         )
 
diff --git a/src/dailyai/pipeline/frames.py b/src/dailyai/pipeline/frames.py
index 4322e65c8..28a920dd8 100644
--- a/src/dailyai/pipeline/frames.py
+++ b/src/dailyai/pipeline/frames.py
@@ -164,6 +164,17 @@ def __str__(self):
         return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"
 
 
+@dataclass()
+class InterimTranscriptionFrame(TextFrame):
+    """A text frame with interim transcription-specific data. Will be placed in
+    the transport's receive queue when a participant speaks."""
+    participantId: str
+    timestamp: str
+
+    def __str__(self):
+        return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"
+
+
 class TTSStartFrame(ControlFrame):
     """Used to indicate the beginning of a TTS response. Following AudioFrames
     are part of the TTS response until an TTEndFrame. These frames can be used
diff --git a/src/dailyai/transports/daily_transport.py b/src/dailyai/transports/daily_transport.py
index 9d32d501b..8424897f1 100644
--- a/src/dailyai/transports/daily_transport.py
+++ b/src/dailyai/transports/daily_transport.py
@@ -10,6 +10,7 @@
 from typing import Any
 
 from dailyai.pipeline.frames import (
+    InterimTranscriptionFrame,
     ReceivedAppMessageFrame,
     TranscriptionFrame,
     UserImageFrame,
@@ -368,8 +369,12 @@ def on_transcription_message(self, message: dict):
         elif "session_id" in message:
             participantId = message["session_id"]
         if self._my_participant_id and participantId != self._my_participant_id:
-            frame = TranscriptionFrame(
-                message["text"], participantId, message["timestamp"])
+            is_final = message["rawResponse"]["is_final"]
+            if is_final:
+                frame = TranscriptionFrame(message["text"], participantId, message["timestamp"])
+            else:
+                frame = InterimTranscriptionFrame(
+                    message["text"], participantId, message["timestamp"])
             asyncio.run_coroutine_threadsafe(
                 self.receive_queue.put(frame), self._loop)
 
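
Usage sketch (not part of the diff): a minimal, hypothetical driver for the "S I E T" case documented in the aggregator's comment block, assuming the dailyai import paths and the frame constructors shown in this patch; the participant id, timestamp and message strings are made up. With interim results wired in, LLMUserResponseAggregator should stay quiet on the stop frame and only emit UserStoppedSpeakingFrame plus LLMMessagesFrame once the final TranscriptionFrame arrives.

    import asyncio

    from dailyai.pipeline.aggregators import LLMUserResponseAggregator
    from dailyai.pipeline.frames import (
        InterimTranscriptionFrame,
        TranscriptionFrame,
        UserStartedSpeakingFrame,
        UserStoppedSpeakingFrame,
    )

    async def main():
        # Non-empty context is required: process_frame() returns early when
        # self.messages is empty.
        messages = [{"role": "system", "content": "You are a helpful assistant."}]
        aggregator = LLMUserResponseAggregator(messages)

        # S I E T: an interim result arrives while the user speaks, and the
        # final transcription only lands after UserStoppedSpeakingFrame.
        # The id and timestamp strings below are placeholders.
        turn = [
            UserStartedSpeakingFrame(),                                  # S
            InterimTranscriptionFrame("Hello", "user-1", "1712860902"),  # I
            UserStoppedSpeakingFrame(),                                  # E
            TranscriptionFrame("Hello there.", "user-1", "1712860903"),  # T
        ]
        for frame in turn:
            out = [f async for f in aggregator.process_frame(frame)]
            # Expected: no output for S, I and E; [UserStoppedSpeakingFrame,
            # LLMMessagesFrame] for the final T.
            print(type(frame).__name__, "->", [type(f).__name__ for f in out])

    asyncio.run(main())

Compared to the end_frame_threshold timer this patch removes, gating on _seen_interim_results waits explicitly for a delayed final transcription instead of only accepting it within a fixed 0.75-second window after the stop frame.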