diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index b806efad4..628f7369b 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -49,4 +49,4 @@ jobs:
       - name: Test with pytest
         run: |
           source .venv/bin/activate
-          pytest --ignore-glob="*to_be_updated*" --ignore-glob=*pipeline_source* src tests
+          pytest
diff --git a/pyproject.toml b/pyproject.toml
index 1d8b84679..a67c9a8e3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -82,7 +82,10 @@ whisper = [ "faster-whisper~=1.1.0" ]
 where = ["src"]
 
 [tool.pytest.ini_options]
+addopts = "--verbose --disable-warnings"
+testpaths = ["tests"]
 pythonpath = ["src"]
+asyncio_default_fixture_loop_scope = "function"
 
 [tool.setuptools_scm]
 local_scheme = "no-local-version"
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/processors/__init__.py b/tests/processors/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/processors/aggregators/__init__.py b/tests/processors/aggregators/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/processors/aggregators/test_llm_response.py b/tests/processors/aggregators/test_llm_response.py
new file mode 100644
index 000000000..ba161329b
--- /dev/null
+++ b/tests/processors/aggregators/test_llm_response.py
@@ -0,0 +1,370 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import unittest
+
+from pipecat.frames.frames import (
+    BotInterruptionFrame,
+    InterimTranscriptionFrame,
+    LLMFullResponseEndFrame,
+    LLMFullResponseStartFrame,
+    StartInterruptionFrame,
+    StopInterruptionFrame,
+    TextFrame,
+    TranscriptionFrame,
+    UserStartedSpeakingFrame,
+    UserStoppedSpeakingFrame,
+)
+from pipecat.processors.aggregators.llm_response import (
+    LLMAssistantContextAggregator,
+    LLMFullResponseAggregator,
+    LLMUserContextAggregator,
+)
+from pipecat.processors.aggregators.openai_llm_context import (
+    OpenAILLMContext,
+    OpenAILLMContextFrame,
+)
+from tests.utils import run_test
+
+
+class TestLLMUserContextAggregator(unittest.IsolatedAsyncioTestCase):
+    # S E ->
+    async def test_s_e(self):
+        """S E case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S T E -> T
+    async def test_s_t_e(self):
+        """S T E case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
+            TranscriptionFrame("Hello", "", ""),
+            StopInterruptionFrame(),
+            UserStoppedSpeakingFrame(),
+        ]
+        expected_returned_frames = [
+            StartInterruptionFrame,
+            UserStartedSpeakingFrame,
+            StopInterruptionFrame,
+            UserStoppedSpeakingFrame,
+            OpenAILLMContextFrame,
+        ]
+        await run_test(context_aggregator, frames_to_send, expected_returned_frames)
+
+    # S I T E -> T
+    async def test_s_i_t_e(self):
+        """S I T E case"""
+        context_aggregator = LLMUserContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            StartInterruptionFrame(),
+            UserStartedSpeakingFrame(),
InterimTranscriptionFrame("This", "", ""), + TranscriptionFrame("This is a test", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S I E T -> T + async def test_s_i_e_t(self): + """S I E T case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S I E I T -> T + async def test_s_i_e_i_t(self): + """S I E I T case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame("This is", "", ""), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S E T -> T + async def test_s_e_t(self): + """S E case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S E I T -> T + async def test_s_e_i_t(self): + """S E I T case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + InterimTranscriptionFrame("This", "", ""), + TranscriptionFrame("This is a test", "", ""), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + await run_test(context_aggregator, frames_to_send, expected_returned_frames) + + # S T1 I E S T2 E -> "T1 T2" + async def test_s_t1_i_e_s_t2_e(self): + """S T1 I E S T2 E case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + InterimTranscriptionFrame("", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + 
TranscriptionFrame("T2", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + ] + expected_returned_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_returned_frames + ) + assert received_down[-1].context.messages[-1]["content"] == "T1 T2" + + # S I E T1 I T2 -> T1 Interruption T2 + async def test_s_i_e_t1_i_t2(self): + """S I E T1 I T2 case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + InterimTranscriptionFrame("", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + InterimTranscriptionFrame("", "", ""), + TranscriptionFrame("T2", "", ""), + ] + expected_down_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + OpenAILLMContextFrame, + ] + expected_up_frames = [ + BotInterruptionFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_down_frames, expected_up_frames + ) + assert received_down[-1].context.messages[-2]["content"] == "T1" + assert received_down[-1].context.messages[-1]["content"] == "T2" + + # S T1 E T2 -> T1 Interruption T2 + async def test_s_t1_e_t2(self): + """S T1 E T2 case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("T2", "", ""), + ] + expected_down_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + OpenAILLMContextFrame, + ] + expected_up_frames = [ + BotInterruptionFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_down_frames, expected_up_frames + ) + assert received_down[-1].context.messages[-2]["content"] == "T1" + assert received_down[-1].context.messages[-1]["content"] == "T2" + + # S E T1 T2 -> T1 Interruption T2 + async def test_s_e_t1_t2(self): + """S E T1 T2 case""" + context_aggregator = LLMUserContextAggregator( + OpenAILLMContext(messages=[{"role": "", "content": ""}]) + ) + frames_to_send = [ + StartInterruptionFrame(), + UserStartedSpeakingFrame(), + StopInterruptionFrame(), + UserStoppedSpeakingFrame(), + TranscriptionFrame("T1", "", ""), + TranscriptionFrame("T2", "", ""), + ] + expected_down_frames = [ + StartInterruptionFrame, + UserStartedSpeakingFrame, + StopInterruptionFrame, + UserStoppedSpeakingFrame, + OpenAILLMContextFrame, + OpenAILLMContextFrame, + ] + expected_up_frames = [ + BotInterruptionFrame, + ] + (received_down, _) = await run_test( + context_aggregator, frames_to_send, expected_down_frames, expected_up_frames + ) + assert received_down[-1].context.messages[-2]["content"] == "T1" + assert received_down[-1].context.messages[-1]["content"] == "T2" + + +class TestLLMAssistantContextAggregator(unittest.IsolatedAsyncioTestCase): + # S T E -> T + async def test_s_t_e(self): + """S T E case""" + context_aggregator = 
+        context_aggregator = LLMAssistantContextAggregator(
+            OpenAILLMContext(messages=[{"role": "", "content": ""}])
+        )
+        frames_to_send = [
+            LLMFullResponseStartFrame(),
+            TextFrame("Hello this is Pipecat speaking!"),
+            TextFrame("How are you?"),
+            LLMFullResponseEndFrame(),
+        ]
+        expected_returned_frames = [
+            LLMFullResponseStartFrame,
+            OpenAILLMContextFrame,
+            LLMFullResponseEndFrame,
+        ]
+        (received_down, _) = await run_test(
+            context_aggregator, frames_to_send, expected_returned_frames
+        )
+        assert (
+            received_down[-2].context.messages[-1]["content"]
+            == "Hello this is Pipecat speaking! How are you?"
+        )
+
+
+class TestLLMFullResponseAggregator(unittest.IsolatedAsyncioTestCase):
+    # S T E -> T
+    async def test_s_t_e(self):
+        """S T E case"""
+        response_aggregator = LLMFullResponseAggregator()
+        frames_to_send = [
+            LLMFullResponseStartFrame(),
+            TextFrame("Hello "),
+            TextFrame("this "),
+            TextFrame("is "),
+            TextFrame("Pipecat!"),
+            LLMFullResponseEndFrame(),
+        ]
+        expected_returned_frames = [
+            LLMFullResponseStartFrame,
+            TextFrame,
+            LLMFullResponseEndFrame,
+        ]
+        (received_down, _) = await run_test(
+            response_aggregator, frames_to_send, expected_returned_frames
+        )
+        assert received_down[-2].text == "Hello this is Pipecat!"
diff --git a/tests/processors/frameworks/__init__.py b/tests/processors/frameworks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_langchain.py b/tests/processors/frameworks/test_langchain.py
similarity index 100%
rename from tests/test_langchain.py
rename to tests/processors/frameworks/test_langchain.py
diff --git a/tests/services/__init__.py b/tests/services/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_ai_services.py b/tests/services/test_ai_services.py
similarity index 100%
rename from tests/test_ai_services.py
rename to tests/services/test_ai_services.py
diff --git a/tests/test_aggregators.py b/tests/skipped/test_aggregators.py
similarity index 100%
rename from tests/test_aggregators.py
rename to tests/skipped/test_aggregators.py
diff --git a/tests/test_daily_transport_service.py b/tests/skipped/test_daily_transport_service.py
similarity index 100%
rename from tests/test_daily_transport_service.py
rename to tests/skipped/test_daily_transport_service.py
diff --git a/tests/test_openai_tts.py b/tests/skipped/test_openai_tts.py
similarity index 100%
rename from tests/test_openai_tts.py
rename to tests/skipped/test_openai_tts.py
diff --git a/tests/test_pipeline.py b/tests/skipped/test_pipeline.py
similarity index 100%
rename from tests/test_pipeline.py
rename to tests/skipped/test_pipeline.py
diff --git a/tests/test_protobuf_serializer.py b/tests/skipped/test_protobuf_serializer.py
similarity index 100%
rename from tests/test_protobuf_serializer.py
rename to tests/skipped/test_protobuf_serializer.py
diff --git a/tests/test_websocket_transport.py b/tests/skipped/test_websocket_transport.py
similarity index 100%
rename from tests/test_websocket_transport.py
rename to tests/skipped/test_websocket_transport.py
diff --git a/tests/test_LLM_user_context_aggregator.py b/tests/test_LLM_user_context_aggregator.py
deleted file mode 100644
index 3498c1e58..000000000
--- a/tests/test_LLM_user_context_aggregator.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# tests/test_custom_user_context.py
-
-"""Tests for CustomLLMUserContextAggregator"""
-
-import asyncio
-import unittest
-
-from dataclasses import dataclass
-from typing import List
-
-from pipecat.clocks.system_clock import SystemClock
-from pipecat.frames.frames import (
-    ControlFrame,
-    Frame,
-    StartFrame,
-    TranscriptionFrame,
-    InterimTranscriptionFrame,
-    UserStartedSpeakingFrame,
-    UserStoppedSpeakingFrame,
-)
-from pipecat.processors.aggregators.llm_response import LLMUserContextAggregator
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
-from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
-
-# Note that UserStartedSpeakingFrame always come with StartInterruptionFrame
-# and UserStoppedSpeakingFrame always come with StopInterruptionFrame
-# S E -> None
-# S T E -> T
-# S I T E -> T
-# S I E T -> T
-# S I E I T -> T
-# S E T -> T
-# S E I T -> T
-# S T1 I E S T2 E -> (T1 T2)
-# S I E T1 I T2 -> T1 Interruption T2
-# S T1 E T2 -> T1 Interruption T2
-# S E T1 B T2 -> T1 Bot Interruption T2
-# S E T1 T2 -> T1 Interruption T2
-
-
-@dataclass
-class EndTestFrame(ControlFrame):
-    pass
-
-
-class QueuedFrameProcessor(FrameProcessor):
-    def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
-        super().__init__()
-        self._queue = queue
-        self._ignore_start = ignore_start
-
-    async def process_frame(self, frame: Frame, direction: FrameDirection):
-        await super().process_frame(frame, direction)
-        if self._ignore_start and isinstance(frame, StartFrame):
-            return
-        await self._queue.put(frame)
-
-
-async def make_test(
-    frames_to_send: List[Frame], expected_returned_frames: List[type]
-) -> List[Frame]:
-    context_aggregator = LLMUserContextAggregator(
-        OpenAILLMContext(messages=[{"role": "", "content": ""}])
-    )
-
-    received = asyncio.Queue()
-    test_processor = QueuedFrameProcessor(received)
-    context_aggregator.link(test_processor)
-
-    await context_aggregator.queue_frame(StartFrame(clock=SystemClock()))
-    for frame in frames_to_send:
-        await context_aggregator.process_frame(frame, direction=FrameDirection.DOWNSTREAM)
-    await context_aggregator.queue_frame(EndTestFrame())
-
-    received_frames: List[Frame] = []
-    running = True
-    while running:
-        frame = await received.get()
-        running = not isinstance(frame, EndTestFrame)
-        if running:
-            received_frames.append(frame)
-
-    assert len(received_frames) == len(expected_returned_frames)
-    for real, expected in zip(received_frames, expected_returned_frames):
-        assert isinstance(real, expected)
-    return received_frames
-
-
-class TestFrameProcessing(unittest.IsolatedAsyncioTestCase):
-    # S E ->
-    async def test_s_e(self):
-        """S E case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            UserStoppedSpeakingFrame(),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-        ]
-        await make_test(frames_to_send, expected_returned_frames)
-
-    # S T E -> T
-    async def test_s_t_e(self):
-        """S T E case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            TranscriptionFrame("Hello", "", ""),
-            UserStoppedSpeakingFrame(),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            OpenAILLMContextFrame,
-        ]
-        await make_test(frames_to_send, expected_returned_frames)
-
-    # S I T E -> T
-    async def test_s_i_t_e(self):
-        """S I T E case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            InterimTranscriptionFrame("This", "", ""),
-            TranscriptionFrame("This is a test", "", ""),
-            UserStoppedSpeakingFrame(),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            OpenAILLMContextFrame,
-        ]
-        await make_test(frames_to_send, expected_returned_frames)
-
-    # S I E T -> T
-    async def test_s_i_e_t(self):
-        """S I E T case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            InterimTranscriptionFrame("This", "", ""),
-            UserStoppedSpeakingFrame(),
-            TranscriptionFrame("This is a test", "", ""),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            OpenAILLMContextFrame,
-        ]
-        await make_test(frames_to_send, expected_returned_frames)
-
-    # S I E I T -> T
-    async def test_s_i_e_i_t(self):
-        """S I E I T case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            InterimTranscriptionFrame("This", "", ""),
-            UserStoppedSpeakingFrame(),
-            InterimTranscriptionFrame("This is", "", ""),
-            TranscriptionFrame("This is a test", "", ""),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            OpenAILLMContextFrame,
-        ]
-        await make_test(frames_to_send, expected_returned_frames)
-
-    # S E T -> T
-    async def test_s_e_t(self):
-        """S E case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            UserStoppedSpeakingFrame(),
-            TranscriptionFrame("This is a test", "", ""),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            OpenAILLMContextFrame,
-        ]
-        await make_test(frames_to_send, expected_returned_frames)
-
-    # S E I T -> T
-    async def test_s_e_i_t(self):
-        """S E I T case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            UserStoppedSpeakingFrame(),
-            InterimTranscriptionFrame("This", "", ""),
-            TranscriptionFrame("This is a test", "", ""),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            OpenAILLMContextFrame,
-        ]
-        await make_test(frames_to_send, expected_returned_frames)
-
-    # S T1 I E S T2 E -> "T1 T2"
-    async def test_s_t1_i_e_s_t2_e(self):
-        """S T1 I E S T2 E case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            TranscriptionFrame("T1", "", ""),
-            InterimTranscriptionFrame("", "", ""),
-            UserStoppedSpeakingFrame(),
-            UserStartedSpeakingFrame(),
-            TranscriptionFrame("T2", "", ""),
-            UserStoppedSpeakingFrame(),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            OpenAILLMContextFrame,
-        ]
-        result = await make_test(frames_to_send, expected_returned_frames)
-        assert result[-1].context.messages[-1]["content"] == "T1 T2"
-
-    # S I E T1 I T2 -> T1 Interruption T2
-    async def test_s_i_e_t1_i_t2(self):
-        """S I E T1 I T2 case"""
-        frames_to_send = [
-            UserStartedSpeakingFrame(),
-            InterimTranscriptionFrame("", "", ""),
-            UserStoppedSpeakingFrame(),
-            TranscriptionFrame("T1", "", ""),
-            InterimTranscriptionFrame("", "", ""),
-            TranscriptionFrame("T2", "", ""),
-        ]
-        expected_returned_frames = [
-            UserStartedSpeakingFrame,
-            UserStoppedSpeakingFrame,
-            OpenAILLMContextFrame,
-            OpenAILLMContextFrame,
-        ]
-        result = await make_test(frames_to_send, expected_returned_frames)
-        assert result[-2].context.messages[-2]["content"] == "T1"
-        assert result[-1].context.messages[-1]["content"] == "T2"
-
-    # # S T1 E T2 -> T1 Interruption T2
-    # async def test_s_t1_e_t2(self):
-    #     """S T1 E T2 case"""
-    #     frames_to_send = [
-    #         UserStartedSpeakingFrame(),
-    #         TranscriptionFrame("T1", "", ""),
-    #         UserStoppedSpeakingFrame(),
-    #         TranscriptionFrame("T2", "", ""),
-    #     ]
-    #     expected_returned_frames = [
-    #         UserStartedSpeakingFrame,
-    #         UserStoppedSpeakingFrame,
-    #         OpenAILLMContextFrame,
-    #         OpenAILLMContextFrame,
-    #     ]
-    #     result = await make_test(frames_to_send, expected_returned_frames)
-    #     assert result[-1].context.messages[-1]["content"] == " T1 T2"
-
-    # # S E T1 T2 -> T1 Interruption T2
-    # async def test_s_e_t1_t2(self):
-    #     """S E T1 T2 case"""
-    #     frames_to_send = [
-    #         UserStartedSpeakingFrame(),
-    #         UserStoppedSpeakingFrame(),
-    #         TranscriptionFrame("T1", "", ""),
-    #         TranscriptionFrame("T2", "", ""),
-    #     ]
-    #     expected_returned_frames = [
-    #         UserStartedSpeakingFrame,
-    #         UserStoppedSpeakingFrame,
-    #         OpenAILLMContextFrame,
-    #         OpenAILLMContextFrame,
-    #     ]
-    #     result = await make_test(frames_to_send, expected_returned_frames)
-    #     assert result[-1].context.messages[-1]["content"] == " T1 T2"
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 000000000..cb3df0c8c
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,96 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+from dataclasses import dataclass
+from typing import List, Tuple
+
+from pipecat.clocks.system_clock import SystemClock
+from pipecat.frames.frames import (
+    ControlFrame,
+    Frame,
+    StartFrame,
+)
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+
+
+@dataclass
+class EndTestFrame(ControlFrame):
+    pass
+
+
+class QueuedFrameProcessor(FrameProcessor):
+    def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
+        super().__init__()
+        self._queue = queue
+        self._ignore_start = ignore_start
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+        if self._ignore_start and isinstance(frame, StartFrame):
+            return
+        await self._queue.put(frame)
+
+
+async def run_test(
+    processor: FrameProcessor,
+    frames_to_send: List[Frame],
+    expected_down_frames: List[type],
+    expected_up_frames: List[type] = [],
+) -> Tuple[List[Frame], List[Frame]]:
+    received_up = asyncio.Queue()
+    received_down = asyncio.Queue()
+    up_processor = QueuedFrameProcessor(received_up)
+    down_processor = QueuedFrameProcessor(received_down)
+
+    up_processor.link(processor)
+    processor.link(down_processor)
+
+    await processor.queue_frame(StartFrame(clock=SystemClock()))
+
+    for frame in frames_to_send:
+        await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
+
+    await processor.queue_frame(EndTestFrame())
+    await processor.queue_frame(EndTestFrame(), FrameDirection.UPSTREAM)
+
+    #
+    # Down frames
+    #
+    received_down_frames: List[Frame] = []
+    running = True
+    while running:
+        frame = await received_down.get()
+        running = not isinstance(frame, EndTestFrame)
+        if running:
+            received_down_frames.append(frame)
+
+    print("received DOWN frames =", received_down_frames)
+
+    assert len(received_down_frames) == len(expected_down_frames)
+
+    for real, expected in zip(received_down_frames, expected_down_frames):
+        assert isinstance(real, expected)
+
+    #
+    # Up frames
+    #
+    received_up_frames: List[Frame] = []
+    running = True
+    while running:
+        frame = await received_up.get()
+        running = not isinstance(frame, EndTestFrame)
+        if running:
+            received_up_frames.append(frame)
+
+    print("received UP frames =", received_up_frames)
+
+    assert len(received_up_frames) == len(expected_up_frames)
+
+    for real, expected in zip(received_up_frames, expected_up_frames):
+        assert isinstance(real, expected)
+
+    return (received_down_frames, received_up_frames)