Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLM user frame processor with tests #450

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 128 additions & 1 deletion src/pipecat/processors/aggregators/llm_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

import sys
from typing import List

from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame, OpenAILLMContext
Expand Down Expand Up @@ -311,3 +310,131 @@ def __init__(self, context: OpenAILLMContext):
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame
)
#CUSTOM CODE: this variable remembers if we prompted the LLM
self.sent_aggregation_after_last_interruption = False

# Relevant functions:
# LLMContextAggregator.async def _push_aggregation(self)
# and
# def LLMResponseAggregator._reset(self):

# The original pipecat implementation is in:
# LLMResponseAggregator.process_frame

# Use cases implemented:
#
# S: Start, E: End, T: Transcription, I: Interim, X: Text
#
# S E -> None
# S T E -> T
# S I T E -> T
# S I E T -> T
# S I E I T -> T
# S E T -> T
# S E I T -> T
#
# S I E T1 I T2 -> T1
#
# and T2 would be dropped.

# We have:
# S = UserStartedSpeakingFrame,
# E = UserStoppedSpeakingFrame,
# T = TranscriptionFrame,
# I = InterimTranscriptionFrame

# Cases we want to handle:
# - Make sure we never delete some aggregation as it is something said by the user
#   - Solves case: S T1 I E S T2 E, where T1 would otherwise be lost
#   - Solves case: S T E Bot T (without E S), where the VAD never re-activates
#   - Solves case: S E T1 T2, where T2 would otherwise be lost (variation of the above)
# For the last cases we also send a StartInterruptionFrame so that re-prompting the LLM does not produce odd, repeated messages.

# So the cases would be:
# S E -> None
# S T E -> T
# S I T E -> T
# S I E T -> T
# S I E I T -> T
# S E T -> T
# S E I T -> T
# S T1 I E S T2 E -> (T1 T2)
# S I E T1 I T2 -> T1 Interruption T2
# S T1 E T2 -> T1 Interruption T2
# S E T1 B T2 -> T1 Bot Interruption T2
# S E T1 T2 -> T1 Interruption T2
# see the tests at test_LLM_user_context_aggregator
async def process_frame(self, frame: Frame, direction: FrameDirection):
    """Aggregate user transcriptions and decide when to (re)prompt the LLM.

    Custom override of the stock aggregator: transcriptions arriving outside
    a VAD start/stop window are no longer dropped, and a transcription that
    arrives after we already prompted the LLM is preceded by a
    StartInterruptionFrame so the new prompt supersedes the previous one.
    See the case table in the comments above and the tests in
    test_LLM_user_context_aggregator.
    """
    await FrameProcessor.process_frame(self, frame, direction)

    send_aggregation = False

    if isinstance(frame, self._start_frame):
        # CUSTOM CODE: do NOT clear the aggregation here, otherwise text
        # collected before this start frame (e.g. case S T1 I E S T2 E)
        # would be lost.
        #self._aggregation = ""
        self._aggregating = True
        self._seen_start_frame = True
        self._seen_end_frame = False
        # CUSTOM CODE: _seen_interim_results must only be updated by the
        # interim frame and the accumulator frame, never reset here.
        #self._seen_interim_results = False
        await self.push_frame(frame, direction)
    elif isinstance(frame, self._end_frame):
        self._seen_end_frame = True
        self._seen_start_frame = False

        # We might have received the end frame but we might still be
        # aggregating (i.e. we have seen interim results but not the final
        # text).
        self._aggregating = self._seen_interim_results or len(self._aggregation) == 0

        # Send the aggregation if we are not aggregating anymore (i.e. no
        # more interim results received).
        send_aggregation = not self._aggregating
        await self.push_frame(frame, direction)
    elif isinstance(frame, self._accumulator_frame):
        # CUSTOM CODE: a transcription arriving after we already prompted
        # the LLM means the user kept talking without the VAD firing; emit
        # an interruption so the re-prompt does not duplicate responses.
        if self.sent_aggregation_after_last_interruption:
            await self.push_frame(StartInterruptionFrame())
            self.sent_aggregation_after_last_interruption = False

        # CUSTOM CODE: do not require `self._aggregating` here, so frames
        # received outside a start/stop window are not lost.
        self._aggregation += f" {frame.text}" if self._aggregation else frame.text
        # We have received a complete sentence, so if we are not inside a
        # start/end window anymore it means we should send the aggregation.
        # CUSTOM CODE: the important thing is having NOT seen a start frame
        # without its end frame (i.e. the user is not still speaking).
        send_aggregation = not self._seen_start_frame
        # We just got our final result, so let's reset interim results.
        self._seen_interim_results = False
    elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
        # CUSTOM CODE: same VAD-less interruption handling as above — an
        # interim result also means the user resumed speaking.
        if self.sent_aggregation_after_last_interruption:
            await self.push_frame(StartInterruptionFrame())
            self.sent_aggregation_after_last_interruption = False
        self._seen_interim_results = True
    elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
        # CUSTOM CODE: a new interruption starts a new turn — flush any
        # pending aggregation first so nothing the user said is dropped.
        self.sent_aggregation_after_last_interruption = False
        await self._push_aggregation()
        # Reset anyways
        self._reset()
        await self.push_frame(frame, direction)
    elif isinstance(frame, LLMMessagesAppendFrame):
        # Extend the conversation history and re-emit it downstream.
        self._messages.extend(frame.messages)
        messages_frame = LLMMessagesFrame(self._messages)
        await self.push_frame(messages_frame)
    elif isinstance(frame, LLMMessagesUpdateFrame):
        # We push the frame downstream so the assistant aggregator gets
        # updated as well.
        await self.push_frame(frame)
        # We can now reset this one.
        self._reset()
        self._messages = frame.messages
        messages_frame = LLMMessagesFrame(self._messages)
        await self.push_frame(messages_frame)
    else:
        await self.push_frame(frame, direction)

    if send_aggregation:
        await self._push_aggregation()

155 changes: 155 additions & 0 deletions src/pipecat/processors/aggregators/test_LLM_user_context_aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# tests/test_custom_user_context.py

"""Tests for CustomLLMUserContextAggregator"""

import unittest


from pipecat.frames.frames import (
Frame,
TranscriptionFrame,
InterimTranscriptionFrame,
StartInterruptionFrame,
StopInterruptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor

# Note that a UserStartedSpeakingFrame always comes with a StartInterruptionFrame
# and a UserStoppedSpeakingFrame always comes with a StopInterruptionFrame
# S E -> None
# S T E -> T
# S I T E -> T
# S I E T -> T
# S I E I T -> T
# S E T -> T
# S E I T -> T
# S T1 I E S T2 E -> (T1 T2)
# S I E T1 I T2 -> T1 Interruption T2
# S T1 E T2 -> T1 Interruption T2
# S E T1 B T2 -> T1 Bot Interruption T2
# S E T1 T2 -> T1 Interruption T2


class StoreFrameProcessor(FrameProcessor):
    """Terminal sink that records every frame pushed to it.

    The caller supplies the list; captured frames are appended to it so
    the test can inspect them after the pipeline run.
    """

    def __init__(self, storage: list[Frame]) -> None:
        super().__init__()
        # Shared with the caller — appended to in process_frame.
        self.storage = storage

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Capture *frame*; nothing is forwarded downstream."""
        self.storage.append(frame)

async def make_test(frames_to_send, expected_returned_frames):
    """Feed *frames_to_send* through a fresh LLMUserContextAggregator and
    check the frames it pushes downstream against *expected_returned_frames*
    (a list of frame classes, compared positionally with isinstance).

    Returns the captured frame list so callers can assert on frame contents.
    """
    aggregator = LLMUserContextAggregator(
        OpenAILLMContext(messages=[{"role": "", "content": ""}])
    )
    captured = []
    aggregator.link(StoreFrameProcessor(captured))

    for outgoing in frames_to_send:
        await aggregator.process_frame(outgoing, direction=FrameDirection.DOWNSTREAM)

    # Dump both sequences to make assertion failures easy to diagnose.
    print("storage")
    for received in captured:
        print(received)
    print("expected_returned_frames")
    for expected_type in expected_returned_frames:
        print(expected_type)

    assert len(captured) == len(expected_returned_frames)
    for expected_type, received in zip(expected_returned_frames, captured):
        assert isinstance(received, expected_type)
    return captured

class TestFrameProcessing(unittest.IsolatedAsyncioTestCase):
    """Frame-sequence cases for the custom LLM user context aggregator.

    Each test name and docstring encode the frame sequence sent
    (S = UserStartedSpeakingFrame, E = UserStoppedSpeakingFrame,
    T = TranscriptionFrame, I = InterimTranscriptionFrame); the comment
    above each test shows the expected aggregation result.
    """

    # S E ->
    async def test_s_e(self):
        """S E case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame()]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame]
        await make_test(frames_to_send, expected_returned_frames)

    # S T E -> T
    async def test_s_t_e(self):
        """S T E case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame]
        await make_test(frames_to_send, expected_returned_frames)

    # S I T E -> T
    async def test_s_i_t_e(self):
        """S I T E case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame]
        await make_test(frames_to_send, expected_returned_frames)

    # S I E T -> T
    async def test_s_i_e_t(self):
        """S I E T case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), TranscriptionFrame("", "", "")]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame]
        await make_test(frames_to_send, expected_returned_frames)

    # S I E I T -> T
    async def test_s_i_e_i_t(self):
        """S I E I T case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", "")]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame]
        await make_test(frames_to_send, expected_returned_frames)

    # S E T -> T
    async def test_s_e_t(self):
        """S E T case"""
        # FIX: docstring previously said "S E case", which described test_s_e.
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(), TranscriptionFrame("", "", "")]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame]
        await make_test(frames_to_send, expected_returned_frames)

    # S E I T -> T
    async def test_s_e_i_t(self):
        """S E I T case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("", "", "")]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame]
        await make_test(frames_to_send, expected_returned_frames)

    # S T1 I E S T2 E -> (T1 T2)
    async def test_s_t1_i_e_s_t2_e(self):
        """S T1 I E S T2 E case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T1", "", ""), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(),
                          StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T2", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame()]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame,
                                    StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame, OpenAILLMContextFrame]
        result = await make_test(frames_to_send, expected_returned_frames)
        assert result[-1].context.messages[-1]["content"] == " T1 T2"

    # S I E T1 I T2 -> T1 Interruption T2
    async def test_s_i_e_t1_i_t2(self):
        """S I E T1 I T2 case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), InterimTranscriptionFrame("", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(),
                          TranscriptionFrame("T1", "", ""), InterimTranscriptionFrame("", "", ""), TranscriptionFrame("T2", "", ""),]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame,
                                    OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame]
        result = await make_test(frames_to_send, expected_returned_frames)
        assert result[-1].context.messages[-1]["content"] == " T1 T2"

    # S T1 E T2 -> T1 Interruption T2
    async def test_s_t1_e_t2(self):
        """S T1 E T2 case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), TranscriptionFrame("T1", "", ""), StopInterruptionFrame(), UserStoppedSpeakingFrame(),
                          TranscriptionFrame("T2", "", ""),]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame,
                                    OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame]
        result = await make_test(frames_to_send, expected_returned_frames)
        assert result[-1].context.messages[-1]["content"] == " T1 T2"

    # S E T1 T2 -> T1 Interruption T2
    async def test_s_e_t1_t2(self):
        """S E T1 T2 case"""
        frames_to_send = [StartInterruptionFrame(), UserStartedSpeakingFrame(), StopInterruptionFrame(), UserStoppedSpeakingFrame(),
                          TranscriptionFrame("T1", "", ""), TranscriptionFrame("T2", "", ""),]
        expected_returned_frames = [StartInterruptionFrame, UserStartedSpeakingFrame, StopInterruptionFrame, UserStoppedSpeakingFrame,
                                    OpenAILLMContextFrame, StartInterruptionFrame, OpenAILLMContextFrame]
        result = await make_test(frames_to_send, expected_returned_frames)
        assert result[-1].context.messages[-1]["content"] == " T1 T2"