Skip to content

Commit

Permalink
aggregators: clear accumulated responses if interruption happens
Browse files Browse the repository at this point in the history
  • Loading branch information
aconchillo committed May 19, 2024
1 parent c0d5054 commit c3bfcbd
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 22 deletions.
26 changes: 15 additions & 11 deletions src/pipecat/processors/aggregators/llm_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
LLMFullResponseEndFrame,
LLMMessagesFrame,
LLMResponseStartFrame,
StartInterruptionFrame,
TextFrame,
LLMResponseEndFrame,
TranscriptionFrame,
Expand Down Expand Up @@ -40,12 +41,9 @@ def __init__(
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False

self._aggregation = ""
self._aggregating = False
# Reset our accumulator state.
self._reset()

#
# Frame processor
Expand Down Expand Up @@ -96,6 +94,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
self._seen_interim_results = True
elif isinstance(frame, StartInterruptionFrame):
self._reset()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

Expand All @@ -108,12 +109,15 @@ async def _push_aggregation(self):
frame = LLMMessagesFrame(self._messages)
await self.push_frame(frame)

# Reset
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
# Reset our accumulator state.
self._reset()

def _reset(self):
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False


class LLMAssistantResponseAggregator(LLMResponseAggregator):
Expand Down
26 changes: 15 additions & 11 deletions src/pipecat/processors/aggregators/user_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pipecat.frames.frames import (
Frame,
InterimTranscriptionFrame,
StartInterruptionFrame,
TextFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
Expand Down Expand Up @@ -56,12 +57,9 @@ def __init__(
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False

self._aggregation = ""
self._aggregating = False
# Reset our accumulator state.
self._reset()

#
# Frame processor
Expand Down Expand Up @@ -112,6 +110,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
self._seen_interim_results = True
elif isinstance(frame, StartInterruptionFrame):
self._reset()
await self.push_frame(frame, direction)
else:
await self.push_frame(frame, direction)

Expand All @@ -122,12 +123,15 @@ async def _push_aggregation(self):
if len(self._aggregation) > 0:
await self.push_frame(TextFrame(self._aggregation.strip()))

# Reset
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False
# Reset our accumulator state.
self._reset()

def _reset(self):
self._aggregation = ""
self._aggregating = False
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False


class UserResponseAggregator(ResponseAggregator):
Expand Down

0 comments on commit c3bfcbd

Please sign in to comment.