diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index b806efad4..628f7369b 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -49,4 +49,4 @@ jobs:
       - name: Test with pytest
         run: |
           source .venv/bin/activate
-          pytest --ignore-glob="*to_be_updated*" --ignore-glob=*pipeline_source* src tests
+          pytest
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ea272f9a1..7d5d028b6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -75,6 +75,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fixed LLM response aggregators to support more use cases such as delayed
+  transcriptions.
+
 - Fixed an issue that could cause the bot to stop talking if there was a user
   interruption before getting any audio from the TTS service.
 
diff --git a/README.md b/README.md
index 2821e8f32..f3f3ff2e0 100644
--- a/README.md
+++ b/README.md
@@ -64,7 +64,7 @@ Available options include:
 | Transport | [Daily (WebRTC)](https://docs.pipecat.ai/server/services/transport/daily), WebSocket, Local | `pip install "pipecat-ai[daily]"` |
 | Video | [Tavus](https://docs.pipecat.ai/server/services/video/tavus), [Simli](https://docs.pipecat.ai/server/services/video/simli) | `pip install "pipecat-ai[tavus,simli]"` |
 | Vision & Image | [Moondream](https://docs.pipecat.ai/server/services/vision/moondream), [fal](https://docs.pipecat.ai/server/services/image-generation/fal) | `pip install "pipecat-ai[moondream]"` |
-| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Noisereduce](https://docs.pipecat.ai/server/utilities/audio/noisereduce-filter) | `pip install "pipecat-ai[silero]"` |
+| Audio Processing | [Silero VAD](https://docs.pipecat.ai/server/utilities/audio/silero-vad-analyzer), [Krisp](https://docs.pipecat.ai/server/utilities/audio/krisp-filter), [Koala](https://docs.pipecat.ai/server/utilities/audio/koala-filter) | `pip install "pipecat-ai[silero]"` |
 | Analytics & Metrics | [Canonical AI](https://docs.pipecat.ai/server/services/analytics/canonical), [Sentry](https://docs.pipecat.ai/server/services/analytics/sentry) | `pip install "pipecat-ai[canonical]"` |
 
 📚 [View full services documentation →](https://docs.pipecat.ai/server/services/supported-services)
diff --git a/pyproject.toml b/pyproject.toml
index 1d8b84679..a67c9a8e3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,7 +82,10 @@ whisper = [ "faster-whisper~=1.1.0" ]
 where = ["src"]
 
 [tool.pytest.ini_options]
+addopts = "--verbose --disable-warnings"
+testpaths = ["tests"]
 pythonpath = ["src"]
+asyncio_default_fixture_loop_scope = "function"
 
 [tool.setuptools_scm]
 local_scheme = "no-local-version"
diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py
index 479746471..e62c7d0b7 100644
--- a/src/pipecat/processors/aggregators/llm_response.py
+++ b/src/pipecat/processors/aggregators/llm_response.py
@@ -7,6 +7,7 @@
 from typing import List, Type
 
 from pipecat.frames.frames import (
+    BotInterruptionFrame,
     Frame,
     InterimTranscriptionFrame,
     LLMFullResponseEndFrame,
@@ -40,6 +41,7 @@ def __init__(
         interim_accumulator_frame: Type[TextFrame] | None = None,
         handle_interruptions: bool = False,
         expect_stripped_words: bool = True,  # if True, need to add spaces between words
+        interrupt_double_accumulator: bool = True,  # if True, interrupt if two or more accumulators are received
     ):
         super().__init__()
 
@@ -51,8 +53,8 @@ def __init__(
         self._interim_accumulator_frame = interim_accumulator_frame
         self._handle_interruptions = handle_interruptions
         self._expect_stripped_words = expect_stripped_words
+        self._interrupt_double_accumulator = interrupt_double_accumulator
 
-        # Reset our accumulator state.
         self._reset()
 
     @property
@@ -69,21 +71,20 @@ def role(self):
     # Use cases implemented:
     #
-    # S: Start, E: End, T: Transcription, I: Interim, X: Text
+    # S: Start, E: End, T: Transcription, I: Interim
     #
-    #        S E -> None
-    #      S T E -> X
-    #    S I T E -> X
-    #    S I E T -> X
-    #  S I E I T -> X
-    #      S E T -> X
-    #    S E I T -> X
-    #
-    # The following case would not be supported:
-    #
-    #    S I E T1 I T2 -> X
-    #
-    # and T2 would be dropped.
+    #        S E -> None  -> User started speaking but no transcription.
+    #      S T E -> T     -> Transcription between user started and stopped speaking.
+    #      S E T -> T     -> Transcription after user stopped speaking.
+    #    S I T E -> T     -> Transcription between user started and stopped speaking (with interims).
+    #    S I E T -> T     -> Transcription after user stopped speaking (with interims).
+    #  S I E I T -> T     -> Transcription after user stopped speaking (with interims).
+    #    S E I T -> T     -> Transcription after user stopped speaking (with interims).
+    #    S T1 I E S T2 E -> "T1 T2" -> Merge two transcriptions if we got a first interim.
+    #  S I E T1 I T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
+    #    S T1 E T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
+    #    S E T1 B T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
+    #    S E T1 T2 -> T1 [Interruption] T2 -> Single user started/stopped, double transcription.
 
     async def process_frame(self, frame: Frame, direction: FrameDirection):
         await super().process_frame(frame, direction)
 
@@ -91,11 +92,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
         send_aggregation = False
 
         if isinstance(frame, self._start_frame):
-            self._aggregation = ""
             self._aggregating = True
             self._seen_start_frame = True
             self._seen_end_frame = False
-            self._seen_interim_results = False
             await self.push_frame(frame, direction)
         elif isinstance(frame, self._end_frame):
             self._seen_end_frame = True
@@ -109,23 +108,36 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
             # Send the aggregation if we are not aggregating anymore (i.e. no
             # more interim results received).
             send_aggregation = not self._aggregating
-            await self.push_frame(frame, direction)
         elif isinstance(frame, self._accumulator_frame):
-            if self._aggregating:
-                if self._expect_stripped_words:
-                    self._aggregation += f" {frame.text}" if self._aggregation else frame.text
-                else:
-                    self._aggregation += frame.text
-                # We have recevied a complete sentence, so if we have seen the
-                # end frame and we were still aggregating, it means we should
-                # send the aggregation.
-                send_aggregation = self._seen_end_frame
+            if (
+                self._interrupt_double_accumulator
+                and self._sent_aggregation_after_last_interruption
+            ):
+                await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
+                self._sent_aggregation_after_last_interruption = False
+
+            if self._expect_stripped_words:
+                self._aggregation += f" {frame.text}" if self._aggregation else frame.text
+            else:
+                self._aggregation += frame.text
+
+            # If we haven't seen the start frame but we got an accumulator frame
+            # it means one of two things: it was delivered before the end frame
+            # or it was delivered late. In both cases we want to send the
+            # aggregation.
+            send_aggregation = not self._seen_start_frame
 
             # We just got our final result, so let's reset interim results.
             self._seen_interim_results = False
         elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
+            if (
+                self._interrupt_double_accumulator
+                and self._sent_aggregation_after_last_interruption
+            ):
+                await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
+                self._sent_aggregation_after_last_interruption = False
             self._seen_interim_results = True
-        elif self._handle_interruptions and isinstance(frame, StartInterruptionFrame):
+        elif isinstance(frame, StartInterruptionFrame) and self._handle_interruptions:
             await self._push_aggregation()
             # Reset anyways
             self._reset()
@@ -142,6 +154,9 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
         if send_aggregation:
             await self._push_aggregation()
 
+        if isinstance(frame, self._end_frame):
+            await self.push_frame(frame, direction)
+
     async def _push_aggregation(self):
         if len(self._aggregation) > 0:
             self._messages.append({"role": self._role, "content": self._aggregation})
@@ -150,6 +165,8 @@ async def _push_aggregation(self):
             # if the tasks gets cancelled we won't be able to clear things up.
             self._aggregation = ""
 
+            self._sent_aggregation_after_last_interruption = True
+
             frame = LLMMessagesFrame(self._messages)
             await self.push_frame(frame)
 
@@ -172,22 +189,11 @@ def _reset(self):
         self._seen_start_frame = False
         self._seen_end_frame = False
         self._seen_interim_results = False
-
-
-class LLMAssistantResponseAggregator(LLMResponseAggregator):
-    def __init__(self, messages: List[dict] = []):
-        super().__init__(
-            messages=messages,
-            role="assistant",
-            start_frame=LLMFullResponseStartFrame,
-            end_frame=LLMFullResponseEndFrame,
-            accumulator_frame=TextFrame,
-            handle_interruptions=True,
-        )
+        self._sent_aggregation_after_last_interruption = False
 
 
 class LLMUserResponseAggregator(LLMResponseAggregator):
-    def __init__(self, messages: List[dict] = []):
+    def __init__(self, messages: List[dict] = [], **kwargs):
         super().__init__(
             messages=messages,
             role="user",
@@ -195,61 +201,21 @@ def __init__(self, messages: List[dict] = []):
             end_frame=UserStoppedSpeakingFrame,
             accumulator_frame=TranscriptionFrame,
             interim_accumulator_frame=InterimTranscriptionFrame,
+            **kwargs,
         )
 
 
-class LLMFullResponseAggregator(FrameProcessor):
-    """This class aggregates Text frames until it receives a
-    LLMFullResponseEndFrame, then emits the concatenated text as
-    a single text frame.
-
-    given the following frames:
-
-        TextFrame("Hello,")
-        TextFrame(" world.")
-        TextFrame(" I am")
-        TextFrame(" an LLM.")
-        LLMFullResponseEndFrame()]
-
-    this processor will yield nothing for the first 4 frames, then
-
-        TextFrame("Hello, world. I am an LLM.")
-        LLMFullResponseEndFrame()
-
-    when passed the last frame.
-
-    >>> async def print_frames(aggregator, frame):
-    ...     async for frame in aggregator.process_frame(frame):
-    ...         if isinstance(frame, TextFrame):
-    ...             print(frame.text)
-    ...         else:
-    ...             print(frame.__class__.__name__)
-
-    >>> aggregator = LLMFullResponseAggregator()
-    >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
-    >>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
-    >>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
-    >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
-    >>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
-    Hello, world. I am an LLM.
-    LLMFullResponseEndFrame
-    """
-
-    def __init__(self):
-        super().__init__()
-        self._aggregation = ""
-
-    async def process_frame(self, frame: Frame, direction: FrameDirection):
-        await super().process_frame(frame, direction)
-
-        if isinstance(frame, TextFrame):
-            self._aggregation += frame.text
-        elif isinstance(frame, LLMFullResponseEndFrame):
-            await self.push_frame(TextFrame(self._aggregation))
-            await self.push_frame(frame)
-            self._aggregation = ""
-        else:
-            await self.push_frame(frame, direction)
+class LLMAssistantResponseAggregator(LLMResponseAggregator):
+    def __init__(self, messages: List[dict] = [], **kwargs):
+        super().__init__(
+            messages=messages,
+            role="assistant",
+            start_frame=LLMFullResponseStartFrame,
+            end_frame=LLMFullResponseEndFrame,
+            accumulator_frame=TextFrame,
+            handle_interruptions=True,
+            **kwargs,
+        )
 
 
 class LLMContextAggregator(LLMResponseAggregator):
@@ -286,15 +252,14 @@ async def _push_aggregation(self):
             # if the tasks gets cancelled we won't be able to clear things up.
             self._aggregation = ""
 
+            self._sent_aggregation_after_last_interruption = True
+
             frame = OpenAILLMContextFrame(self._context)
             await self.push_frame(frame)
 
-        # Reset our accumulator state.
-        self._reset()
-
 
 class LLMAssistantContextAggregator(LLMContextAggregator):
-    def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
+    def __init__(self, context: OpenAILLMContext, **kwargs):
         super().__init__(
             messages=[],
             context=context,
@@ -303,12 +268,12 @@ def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = T
             end_frame=LLMFullResponseEndFrame,
             accumulator_frame=TextFrame,
             handle_interruptions=True,
-            expect_stripped_words=expect_stripped_words,
+            **kwargs,
         )
 
 
 class LLMUserContextAggregator(LLMContextAggregator):
-    def __init__(self, context: OpenAILLMContext):
+    def __init__(self, context: OpenAILLMContext, **kwargs):
         super().__init__(
             messages=[],
             context=context,
@@ -317,4 +282,69 @@ def __init__(self, context: OpenAILLMContext):
             end_frame=UserStoppedSpeakingFrame,
             accumulator_frame=TranscriptionFrame,
             interim_accumulator_frame=InterimTranscriptionFrame,
+            **kwargs,
         )
+
+
+class LLMFullResponseAggregator(FrameProcessor):
+    """This class aggregates Text frames between LLMFullResponseStartFrame and
+    LLMFullResponseEndFrame, then emits the concatenated text as a single text
+    frame.
+
+    given the following frames:
+
+        LLMFullResponseStartFrame()
+        TextFrame("Hello,")
+        TextFrame(" world.")
+        TextFrame(" I am")
+        TextFrame(" an LLM.")
+        LLMFullResponseEndFrame()
+
+    this processor will push,
+
+        LLMFullResponseStartFrame()
+        TextFrame("Hello, world. I am an LLM.")
+        LLMFullResponseEndFrame()
+
+    when passed the last frame.
+
+    >>> async def print_frames(aggregator, frame):
+    ...     async for frame in aggregator.process_frame(frame):
+    ...         if isinstance(frame, TextFrame):
+    ...             print(frame.text)
+    ...         else:
+    ...             print(frame.__class__.__name__)
+
+    >>> aggregator = LLMFullResponseAggregator()
+    >>> asyncio.run(print_frames(aggregator, LLMFullResponseStartFrame()))
+    >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
+    >>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
+    >>> asyncio.run(print_frames(aggregator, TextFrame(" I am")))
+    >>> asyncio.run(print_frames(aggregator, TextFrame(" an LLM.")))
+    >>> asyncio.run(print_frames(aggregator, LLMFullResponseEndFrame()))
+    LLMFullResponseStartFrame
+    Hello, world. I am an LLM.
+    LLMFullResponseEndFrame
+
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._aggregation = ""
+        self._seen_start_frame = False
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, LLMFullResponseStartFrame):
+            self._seen_start_frame = True
+            await self.push_frame(frame, direction)
+        elif isinstance(frame, LLMFullResponseEndFrame):
+            self._seen_start_frame = False
+            await self.push_frame(TextFrame(self._aggregation))
+            await self.push_frame(frame)
+            self._aggregation = ""
+        elif isinstance(frame, TextFrame) and self._seen_start_frame:
+            self._aggregation += frame.text
+        else:
+            await self.push_frame(frame, direction)
diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py
index e3dfe0bce..521918b69 100644
--- a/src/pipecat/processors/frame_processor.py
+++ b/src/pipecat/processors/frame_processor.py
@@ -311,8 +311,15 @@ def __create_push_task(self):
         self.__push_frame_task = self.get_event_loop().create_task(self.__push_frame_task_handler())
 
     async def __cancel_push_task(self):
-        self.__push_frame_task.cancel()
-        await self.__push_frame_task
+        try:
+            self.__push_frame_task.cancel()
+            await self.__push_frame_task
+        except asyncio.CancelledError:
+            # TODO(aleix): Investigate why this is really needed. So far, this is
+            # necessary because of how pytest works. If a task is cancelled,
+            # pytest will know the task has been cancelled even if
+            # `asyncio.CancelledError` is handled internally in the task.
+            pass
 
     async def __push_frame_task_handler(self):
         running = True
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/processors/__init__.py b/tests/processors/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/processors/aggregators/__init__.py b/tests/processors/aggregators/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/processors/aggregators/test_llm_response.py b/tests/processors/aggregators/test_llm_response.py
new file mode 100644
index 000000000..ba161329b
--- /dev/null
+++ b/tests/processors/aggregators/test_llm_response.py
@@ -0,0 +1,370 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import unittest
+
+from pipecat.frames.frames import (
+    BotInterruptionFrame,
+    InterimTranscriptionFrame,
+    LLMFullResponseEndFrame,
+    LLMFullResponseStartFrame,
+    StartInterruptionFrame,
+    StopInterruptionFrame,
+    TextFrame,
+    TranscriptionFrame,
+    UserStartedSpeakingFrame,
+    UserStoppedSpeakingFrame,
+)
+from pipecat.processors.aggregators.llm_response import (
+    LLMAssistantContextAggregator,
+    LLMFullResponseAggregator,
+    LLMUserContextAggregator,
+)
+from pipecat.processors.aggregators.openai_llm_context import (
+    OpenAILLMContext,
+    OpenAILLMContextFrame,
+)
+from tests.utils import run_test
+
+
+class TestLLMUserContextAggregator(unittest.IsolatedAsyncioTestCase):
+    # S E ->
+    async def test_s_e(self):
+        """S E case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S T E -> T
+    async def test_s_t_e(self):
+        """S T E case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            TranscriptionFrame("Hello", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S I T E -> T
+    async def test_s_i_t_e(self):
+        """S I T E case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            InterimTranscriptionFrame("This", "", ""),
+            TranscriptionFrame("This is a test", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S I E T -> T
+    async def test_s_i_e_t(self):
+        """S I E T case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            InterimTranscriptionFrame("This", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+            TranscriptionFrame("This is a test", "", ""),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S I E I T -> T
+    async def test_s_i_e_i_t(self):
+        """S I E I T case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            InterimTranscriptionFrame("This", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+            InterimTranscriptionFrame("This is", "", ""),
+            TranscriptionFrame("This is a test", "", ""),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S E T -> T
+    async def test_s_e_t(self):
+        """S E T case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+            TranscriptionFrame("This is a test", "", ""),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S E I T -> T
+    async def test_s_e_i_t(self):
+        """S E I T case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+            InterimTranscriptionFrame("This", "", ""),
+            TranscriptionFrame("This is a test", "", ""),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S T1 I E S T2 E -> "T1 T2"
+    async def test_s_t1_i_e_s_t2_e(self):
+        """S T1 I E S T2 E case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            TranscriptionFrame("T1", "", ""),
+            InterimTranscriptionFrame("", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            TranscriptionFrame("T2", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        (received_down, _) = await run_test(
+            context_aggregator, frames_to_send, expected_returned_frames
+        )
+        assert received_down[-1].context.messages[-1]["content"] == "T1 T2"
+
+    # S I E T1 I T2 -> T1 Interruption T2
+    async def test_s_i_e_t1_i_t2(self):
+        """S I E T1 I T2 case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            InterimTranscriptionFrame("", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+            TranscriptionFrame("T1", "", ""),
+            InterimTranscriptionFrame("", "", ""),
+            TranscriptionFrame("T2", "", ""),
+        ]
+        expected_down_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+            OpenAILLMContextFrame,
+        ]
+        expected_up_frames = [
+            BotInterruptionFrame,
+        ]
+        (received_down, _) = await run_test(
+            context_aggregator, frames_to_send, expected_down_frames, expected_up_frames
+        )
+        assert received_down[-1].context.messages[-2]["content"] == "T1"
+        assert received_down[-1].context.messages[-1]["content"] == "T2"
+
+    # S T1 E T2 -> T1 Interruption T2
+    async def test_s_t1_e_t2(self):
+        """S T1 E T2 case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            TranscriptionFrame("T1", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+            TranscriptionFrame("T2", "", ""),
+        ]
+        expected_down_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+            OpenAILLMContextFrame,
+        ]
+        expected_up_frames = [
+            BotInterruptionFrame,
+        ]
+        (received_down, _) = await run_test(
+            context_aggregator, frames_to_send, expected_down_frames, expected_up_frames
+        )
+        assert received_down[-1].context.messages[-2]["content"] == "T1"
+        assert received_down[-1].context.messages[-1]["content"] == "T2"
+
+    # S E T1 T2 -> T1 Interruption T2
+    async def test_s_e_t1_t2(self):
+        """S E T1 T2 case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+            TranscriptionFrame("T1", "", ""),
+            TranscriptionFrame("T2", "", ""),
+        ]
+        expected_down_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+            OpenAILLMContextFrame,
+        ]
+        expected_up_frames = [
+            BotInterruptionFrame,
+        ]
+        (received_down, _) = await run_test(
+            context_aggregator, frames_to_send, expected_down_frames, expected_up_frames
+        )
+        assert received_down[-1].context.messages[-2]["content"] == "T1"
+        assert received_down[-1].context.messages[-1]["content"] == "T2"
+
+
+class TestLLMAssistantContextAggregator(unittest.IsolatedAsyncioTestCase):
+    # S T E -> T
+    async def test_s_t_e(self):
+        """S T E case"""
+        context_aggregator = LLMAssistantContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            LLMFullResponseStartFrame(),
+            TextFrame("Hello this is Pipecat speaking!"),
+            TextFrame("How are you?"),
+            LLMFullResponseEndFrame(),
+        ]
+        expected_returned_frames = [
+            LLMFullResponseStartFrame,
+            OpenAILLMContextFrame,
+            LLMFullResponseEndFrame,
+        ]
+        (received_down, _) = await run_test(
+            context_aggregator, frames_to_send, expected_returned_frames
+        )
+        assert (
+            received_down[-2].context.messages[-1]["content"]
+            == "Hello this is Pipecat speaking! How are you?"
+        )
+
+
+class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
+    # S T E -> T
+    async def test_s_t_e(self):
+        """S T E case"""
+        response_aggregator = LLMFullResponseAggregator()
+        frames_to_send = [
+            LLMFullResponseStartFrame(),
+            TextFrame("Hello "),
+            TextFrame("this "),
+            TextFrame("is "),
+            TextFrame("Pipecat!"),
+            LLMFullResponseEndFrame(),
+        ]
+        expected_returned_frames = [
+            LLMFullResponseStartFrame,
+            TextFrame,
+            LLMFullResponseEndFrame,
+        ]
+        (received_down, _) = await run_test(
+            response_aggregator, frames_to_send, expected_returned_frames
+        )
+        assert received_down[-2].text == "Hello this is Pipecat!"
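The tests above also show the public knob this change exposes. As a minimal usage sketch (the system message here is a placeholder, and `interrupt_double_accumulator` already defaults to `True`, so it only needs to be passed explicitly to opt out), the forwarded `**kwargs` let callers tune the aggregators when wiring a pipeline:

    from pipecat.processors.aggregators.llm_response import (
        LLMAssistantContextAggregator,
        LLMUserContextAggregator,
    )
    from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext

    # Placeholder system prompt; any OpenAILLMContext works here.
    context = OpenAILLMContext(messages=[{"role": "system", "content": "You are a helpful bot."}])

    # Keyword arguments now flow through **kwargs down to LLMResponseAggregator.
    user_aggregator = LLMUserContextAggregator(context, interrupt_double_accumulator=False)
    assistant_aggregator = LLMAssistantContextAggregator(context, expect_stripped_words=True)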
diff --git a/tests/processors/frameworks/__init__.py b/tests/processors/frameworks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_langchain.py b/tests/processors/frameworks/test_langchain.py
similarity index 100%
rename from tests/test_langchain.py
rename to tests/processors/frameworks/test_langchain.py
diff --git a/tests/services/__init__.py b/tests/services/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_ai_services.py b/tests/services/test_ai_services.py
similarity index 100%
rename from tests/test_ai_services.py
rename to tests/services/test_ai_services.py
diff --git a/tests/test_aggregators.py b/tests/skipped/test_aggregators.py
similarity index 100%
rename from tests/test_aggregators.py
rename to tests/skipped/test_aggregators.py
diff --git a/tests/test_daily_transport_service.py b/tests/skipped/test_daily_transport_service.py
similarity index 100%
rename from tests/test_daily_transport_service.py
rename to tests/skipped/test_daily_transport_service.py
diff --git a/tests/test_openai_tts.py b/tests/skipped/test_openai_tts.py
similarity index 100%
rename from tests/test_openai_tts.py
rename to tests/skipped/test_openai_tts.py
diff --git a/tests/test_pipeline.py b/tests/skipped/test_pipeline.py
similarity index 100%
rename from tests/test_pipeline.py
rename to tests/skipped/test_pipeline.py
diff --git a/tests/test_protobuf_serializer.py b/tests/skipped/test_protobuf_serializer.py
similarity index 100%
rename from tests/test_protobuf_serializer.py
rename to tests/skipped/test_protobuf_serializer.py
diff --git a/tests/test_websocket_transport.py b/tests/skipped/test_websocket_transport.py
similarity index 100%
rename from tests/test_websocket_transport.py
rename to tests/skipped/test_websocket_transport.py
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 000000000..cb3df0c8c
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,96 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+from dataclasses import dataclass
+from typing import List, Tuple
+
+from pipecat.clocks.system_clock import SystemClock
+from pipecat.frames.frames import (
+    ControlFrame,
+    Frame,
+    StartFrame,
+)
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+
+
+@dataclass
+class EndTestFrame(ControlFrame):
+    pass
+
+
+class QueuedFrameProcessor(FrameProcessor):
+    def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
+        super().__init__()
+        self._queue = queue
+        self._ignore_start = ignore_start
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+        if self._ignore_start and isinstance(frame, StartFrame):
+            return
+        await self._queue.put(frame)
+
+
+async def run_test(
+    processor: FrameProcessor,
+    frames_to_send: List[Frame],
+    expected_down_frames: List[type],
+    expected_up_frames: List[type] = [],
+) -> Tuple[List[Frame], List[Frame]]:
+    received_up = asyncio.Queue()
+    received_down = asyncio.Queue()
+    up_processor = QueuedFrameProcessor(received_up)
+    down_processor = QueuedFrameProcessor(received_down)
+
+    up_processor.link(processor)
+    processor.link(down_processor)
+
+    await processor.queue_frame(StartFrame(clock=SystemClock()))
+
+    for frame in frames_to_send:
+        await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
+
+    await processor.queue_frame(EndTestFrame())
+    await processor.queue_frame(EndTestFrame(), FrameDirection.UPSTREAM)
+
+    #
+    # Down frames
+    #
+    received_down_frames: List[Frame] = []
+    running = True
+    while running:
+        frame = await received_down.get()
+        running = not isinstance(frame, EndTestFrame)
+        if running:
+            received_down_frames.append(frame)
+
+    print("received DOWN frames =", received_down_frames)
+
+    assert len(received_down_frames) == len(expected_down_frames)
+
+    for real, expected in zip(received_down_frames, expected_down_frames):
+        assert isinstance(real, expected)
+
+    #
+    # Up frames
+    #
+    received_up_frames: List[Frame] = []
+    running = True
+    while running:
+        frame = await received_up.get()
+        running = not isinstance(frame, EndTestFrame)
+        if running:
+            received_up_frames.append(frame)
+
+    print("received UP frames =", received_up_frames)
+
+    assert len(received_up_frames) == len(expected_up_frames)
+
+    for real, expected in zip(received_up_frames, expected_up_frames):
+        assert isinstance(real, expected)
+
+    return (received_down_frames, received_up_frames)
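For completeness, a minimal sketch of driving `run_test()` outside a `unittest.IsolatedAsyncioTestCase`, mirroring the `LLMFullResponseAggregator` test above. It assumes the test helpers are importable (the new `testpaths`/`pythonpath` settings in `pyproject.toml` handle this under pytest); the driver itself is illustrative, not part of this change:

    import asyncio

    from pipecat.frames.frames import (
        LLMFullResponseEndFrame,
        LLMFullResponseStartFrame,
        TextFrame,
    )
    from pipecat.processors.aggregators.llm_response import LLMFullResponseAggregator
    from tests.utils import run_test


    async def main():
        aggregator = LLMFullResponseAggregator()
        frames_to_send = [
            LLMFullResponseStartFrame(),
            TextFrame("Hello, "),
            TextFrame("world."),
            LLMFullResponseEndFrame(),
        ]
        # The aggregator pushes one combined TextFrame between the start/end frames.
        expected_down_frames = [LLMFullResponseStartFrame, TextFrame, LLMFullResponseEndFrame]
        (received_down, _) = await run_test(aggregator, frames_to_send, expected_down_frames)
        assert received_down[1].text == "Hello, world."


    asyncio.run(main())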