From 8ab86abae2ca7ef0915fd13be16471dd64844d3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Wed, 10 Apr 2024 23:24:05 -0700 Subject: [PATCH] aggregators: allow TranscriptionFrame after an end frame threshold --- src/dailyai/pipeline/aggregators.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/dailyai/pipeline/aggregators.py b/src/dailyai/pipeline/aggregators.py index 78b75afbe..956e5149c 100644 --- a/src/dailyai/pipeline/aggregators.py +++ b/src/dailyai/pipeline/aggregators.py @@ -1,5 +1,6 @@ import asyncio import re +import time from dailyai.pipeline.frame_processor import FrameProcessor @@ -107,6 +108,7 @@ def __init__( end_frame, accumulator_frame, pass_through=True, + end_frame_threshold=0.75, ): self.aggregation = "" self.aggregating = False @@ -116,6 +118,8 @@ def __init__( self._end_frame = end_frame self._accumulator_frame = accumulator_frame self._pass_through = pass_through + self._end_frame_threshold = end_frame_threshold + self._last_end_frame_time = 0 async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: if not self.messages: @@ -128,14 +132,21 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: # Sometimes VAD triggers quickly on and off. If we don't get any transcription, # it creates empty LLM message queue frames if len(self.aggregation) > 0: - self.messages.append( - {"role": self._role, "content": self.aggregation}) + self.messages.append({"role": self._role, "content": self.aggregation}) self.aggregation = "" yield self._end_frame() yield LLMMessagesFrame(self.messages) + self._last_end_frame_time = time.time() elif isinstance(frame, self._accumulator_frame): + # Also accept transcription frames received for a short period after + # the last end frame was received. + diff_time = time.time() - self._last_end_frame_time if self.aggregating: self.aggregation += f" {frame.text}" + elif diff_time <= self._end_frame_threshold: + self.messages.append({"role": self._role, "content": frame.text}) + yield self._end_frame() + yield LLMMessagesFrame(self.messages) if self._pass_through: yield frame else: