From c3bfcbd562cdc8ce8452b335ee6c8b774664a0f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleix=20Conchillo=20Flaqu=C3=A9?= Date: Sun, 19 May 2024 10:20:17 -0700 Subject: [PATCH] aggregators: clear accumulated responses if interruption happens --- .../processors/aggregators/llm_response.py | 26 +++++++++++-------- .../processors/aggregators/user_response.py | 26 +++++++++++-------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 3b9c07fe6..853217064 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -13,6 +13,7 @@ LLMFullResponseEndFrame, LLMMessagesFrame, LLMResponseStartFrame, + StartInterruptionFrame, TextFrame, LLMResponseEndFrame, TranscriptionFrame, @@ -40,12 +41,9 @@ def __init__( self._end_frame = end_frame self._accumulator_frame = accumulator_frame self._interim_accumulator_frame = interim_accumulator_frame - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - self._aggregation = "" - self._aggregating = False + # Reset our accumulator state. + self._reset() # # Frame processor @@ -96,6 +94,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): self._seen_interim_results = False elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): self._seen_interim_results = True + elif isinstance(frame, StartInterruptionFrame): + self._reset() + await self.push_frame(frame, direction) else: await self.push_frame(frame, direction) @@ -108,12 +109,15 @@ async def _push_aggregation(self): frame = LLMMessagesFrame(self._messages) await self.push_frame(frame) - # Reset - self._aggregation = "" - self._aggregating = False - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False + # Reset our accumulator state. + self._reset() + + def _reset(self): + self._aggregation = "" + self._aggregating = False + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False class LLMAssistantResponseAggregator(LLMResponseAggregator): diff --git a/src/pipecat/processors/aggregators/user_response.py b/src/pipecat/processors/aggregators/user_response.py index 5c6520f1f..5b1a8e309 100644 --- a/src/pipecat/processors/aggregators/user_response.py +++ b/src/pipecat/processors/aggregators/user_response.py @@ -8,6 +8,7 @@ from pipecat.frames.frames import ( Frame, InterimTranscriptionFrame, + StartInterruptionFrame, TextFrame, TranscriptionFrame, UserStartedSpeakingFrame, @@ -56,12 +57,9 @@ def __init__( self._end_frame = end_frame self._accumulator_frame = accumulator_frame self._interim_accumulator_frame = interim_accumulator_frame - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False - self._aggregation = "" - self._aggregating = False + # Reset our accumulator state. + self._reset() # # Frame processor @@ -112,6 +110,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): self._seen_interim_results = False elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame): self._seen_interim_results = True + elif isinstance(frame, StartInterruptionFrame): + self._reset() + await self.push_frame(frame, direction) else: await self.push_frame(frame, direction) @@ -122,12 +123,15 @@ async def _push_aggregation(self): if len(self._aggregation) > 0: await self.push_frame(TextFrame(self._aggregation.strip())) - # Reset - self._aggregation = "" - self._aggregating = False - self._seen_start_frame = False - self._seen_end_frame = False - self._seen_interim_results = False + # Reset our accumulator state. + self._reset() + + def _reset(self): + self._aggregation = "" + self._aggregating = False + self._seen_start_frame = False + self._seen_end_frame = False + self._seen_interim_results = False class UserResponseAggregator(ResponseAggregator):