From 70652bbea2285621271b5b5fb9938924a7684fbf Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Sun, 18 Feb 2024 17:27:59 -0600 Subject: [PATCH] Add more tests --- .github/workflows/test.yml | 2 + requirements.txt | 2 +- setup.py | 2 - tests/test_satellite.py | 198 +++++++++++++++++++++++++++++ tests/test_wake_streaming.py | 4 +- tox.ini | 14 ++ wyoming_satellite/event_handler.py | 5 +- wyoming_satellite/satellite.py | 38 +++++- 8 files changed, 257 insertions(+), 8 deletions(-) create mode 100644 tests/test_satellite.py create mode 100644 tox.ini diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bad7ffa..dc27a08 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,5 +1,7 @@ +--- name: test +# yamllint disable-line rule:truthy on: workflow_dispatch: pull_request: diff --git a/requirements.txt b/requirements.txt index 66d5da8..1eb7bdf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -wyoming==1.5.2 +wyoming==1.5.3 zeroconf==0.88.0 pyring-buffer==1.0.0 diff --git a/setup.py b/setup.py index d8422f3..8aba9bd 100644 --- a/setup.py +++ b/setup.py @@ -52,8 +52,6 @@ def get_requirements(req_path: Path) -> List[str]: "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/test_satellite.py b/tests/test_satellite.py new file mode 100644 index 0000000..528dd5f --- /dev/null +++ b/tests/test_satellite.py @@ -0,0 +1,198 @@ +import asyncio +import io +import logging +from collections.abc import Iterable +from pathlib import Path +from typing import Final, Optional +from unittest.mock import patch + +import pytest +from wyoming.asr import Transcript +from wyoming.audio import AudioChunk +from wyoming.client import AsyncClient +from wyoming.event import Event, async_read_event +from wyoming.pipeline import PipelineStage, RunPipeline +from wyoming.satellite import RunSatellite, StreamingStarted, StreamingStopped +from wyoming.wake import Detection + +from wyoming_satellite import ( + EventSettings, + MicSettings, + SatelliteSettings, + WakeSettings, + WakeStreamingSatellite, +) + +from .shared import AUDIO_CHUNK + +_LOGGER = logging.getLogger() + +TIMEOUT: Final = 1 + + +class MicClient(AsyncClient): + def __init__(self) -> None: + super().__init__() + + async def read_event(self) -> Optional[Event]: + await asyncio.sleep(AUDIO_CHUNK.seconds) + return AUDIO_CHUNK.event() + + async def write_event(self, event: Event) -> None: + # Output only + pass + + +class WakeClient(AsyncClient): + def __init__(self) -> None: + super().__init__() + self._event_ready = asyncio.Event() + self._event: Optional[Event] = None + self._detected: bool = False + + async def read_event(self) -> Optional[Event]: + await self._event_ready.wait() + self._event_ready.clear() + return self._event + + async def write_event(self, event: Event) -> None: + if AudioChunk.is_type(event.type): + if not self._detected: + self._detected = True + self._event = Detection().event() + self._event_ready.set() + + +class EventClient(AsyncClient): + def __init__(self) -> None: + super().__init__() + self.detection = asyncio.Event() + self.streaming_started = asyncio.Event() + self.streaming_stopped = asyncio.Event() + + async def read_event(self) -> Optional[Event]: + # Input only + return None + + async def write_event(self, event: Event) -> None: + if Detection.is_type(event.type): + self.detection.set() + elif StreamingStarted.is_type(event.type): + self.streaming_started.set() + elif StreamingStopped.is_type(event.type): + self.streaming_stopped.set() + + +class FakeStreamReaderWriter: + def __init__(self) -> None: + self._undrained_data = bytes() + self._value = bytes() + self._data_ready = asyncio.Event() + + def write(self, data: bytes) -> None: + self._undrained_data += data + + def writelines(self, data: Iterable[bytes]) -> None: + for line in data: + self.write(line) + + async def drain(self) -> None: + self._value += self._undrained_data + self._undrained_data = bytes() + self._data_ready.set() + self._data_ready.clear() + + async def readline(self) -> bytes: + while b"\n" not in self._value: + await self._data_ready.wait() + + with io.BytesIO(self._value) as value_io: + data = value_io.readline() + self._value = self._value[len(data) :] + return data + + async def readexactly(self, n: int) -> bytes: + while len(self._value) < n: + await self._data_ready.wait() + + data = self._value[:n] + self._value = self._value[n:] + return data + + +@pytest.mark.asyncio +async def test_satellite_and_server(tmp_path: Path) -> None: + mic_client = MicClient() + wake_client = WakeClient() + event_client = EventClient() + + with patch( + "wyoming_satellite.satellite.SatelliteBase._make_mic_client", + return_value=mic_client, + ), patch( + "wyoming_satellite.satellite.SatelliteBase._make_wake_client", + return_value=wake_client, + ), patch( + "wyoming_satellite.satellite.SatelliteBase._make_event_client", + return_value=event_client, + ): + satellite = WakeStreamingSatellite( + SatelliteSettings( + mic=MicSettings(uri="test"), + wake=WakeSettings(uri="test"), + event=EventSettings(uri="test"), + ) + ) + + # Fake server connection + server_io = FakeStreamReaderWriter() + await satellite.set_server("test", server_io) # type: ignore + + async def event_from_satellite() -> Optional[Event]: + return await async_read_event(server_io) + + satellite_task = asyncio.create_task(satellite.run(), name="satellite") + await satellite.event_from_server(RunSatellite().event()) + + # Trigger detection + event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) + assert event is not None + assert Detection.is_type(event.type), event + + # Pipeline should start + event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) + assert event is not None + assert RunPipeline.is_type(event.type), event + run_pipeline = RunPipeline.from_event(event) + assert run_pipeline.start_stage == PipelineStage.ASR + + # No TTS + assert run_pipeline.end_stage == PipelineStage.HANDLE + + # Event service should have received detection + await asyncio.wait_for(event_client.detection.wait(), timeout=TIMEOUT) + + # Server should be receiving audio now + assert satellite.is_streaming, "Not streaming" + for _ in range(5): + event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) + assert event is not None + assert AudioChunk.is_type(event.type) + + # Event service should have received streaming start + await asyncio.wait_for(event_client.streaming_started.wait(), timeout=TIMEOUT) + + # Send transcript + await satellite.event_from_server(Transcript(text="test").event()) + + # Wait for streaming to stop + while satellite.is_streaming: + event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) + assert event is not None + assert AudioChunk.is_type(event.type) + + # Event service should have received streaming stop + await asyncio.wait_for(event_client.streaming_stopped.wait(), timeout=TIMEOUT) + + await satellite.stop() + await satellite_task diff --git a/tests/test_wake_streaming.py b/tests/test_wake_streaming.py index ec98ab9..0977509 100644 --- a/tests/test_wake_streaming.py +++ b/tests/test_wake_streaming.py @@ -101,8 +101,8 @@ async def test_multiple_wakeups(tmp_path: Path) -> None: await satellite.event_from_server(Transcript("test").event()) # Should not trigger again within refractory period (default: 5 sec) - # with pytest.raises(asyncio.TimeoutError): - # await asyncio.wait_for(event_client.wake_event.wait(), timeout=0.15) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(event_client.wake_event.wait(), timeout=0.15) await satellite.stop() await satellite_task diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..80dd236 --- /dev/null +++ b/tox.ini @@ -0,0 +1,14 @@ +[tox] +env_list = + py{39,310,311} +minversion = 4.12.1 + +[testenv] +description = run the tests with pytest +package = wheel +wheel_build_env = .pkg +deps = + pytest>=7,<8 + pytest-asyncio<1 +commands = + pytest {tty:--color=yes} {posargs} diff --git a/wyoming_satellite/event_handler.py b/wyoming_satellite/event_handler.py index f68179b..58f046b 100644 --- a/wyoming_satellite/event_handler.py +++ b/wyoming_satellite/event_handler.py @@ -26,7 +26,7 @@ def __init__( super().__init__(*args, **kwargs) self.cli_args = cli_args - self.wyoming_info_event = wyoming_info.event() + self.wyoming_info = wyoming_info self.client_id = str(time.monotonic_ns()) self.satellite = satellite @@ -35,7 +35,8 @@ def __init__( async def handle_event(self, event: Event) -> bool: """Handle events from the server.""" if Describe.is_type(event.type): - await self.write_event(self.wyoming_info_event) + await self.satellite.update_info(self.wyoming_info) + await self.write_event(self.wyoming_info.event()) return True if self.satellite.server_id is None: diff --git a/wyoming_satellite/satellite.py b/wyoming_satellite/satellite.py index 3f2a19a..6102254 100644 --- a/wyoming_satellite/satellite.py +++ b/wyoming_satellite/satellite.py @@ -15,6 +15,7 @@ from wyoming.client import AsyncClient from wyoming.error import Error from wyoming.event import Event, async_write_event +from wyoming.info import Describe, Info from wyoming.mic import MicProcessAsyncClient from wyoming.ping import Ping, Pong from wyoming.pipeline import PipelineStage, RunPipeline @@ -46,6 +47,7 @@ _PONG_TIMEOUT: Final = 5 _PING_SEND_DELAY: Final = 2 +_WAKE_INFO_TIMEOUT: Final = 2 class State(Enum): @@ -898,6 +900,13 @@ async def _disconnect() -> None: await _disconnect() + # ------------------------------------------------------------------------- + # Info + # ------------------------------------------------------------------------- + + async def update_info(self, info: Info) -> None: + pass + # ----------------------------------------------------------------------------- @@ -1150,6 +1159,9 @@ def __init__(self, settings: SatelliteSettings) -> None: self._is_paused = False + self._wake_info: Optional[Info] = None + self._wake_info_ready = asyncio.Event() + async def event_from_server(self, event: Event) -> None: # Only check event types once is_run_satellite = False @@ -1243,8 +1255,13 @@ async def event_from_mic( await self.event_to_wake(event) async def event_from_wake(self, event: Event) -> None: + if Info.is_type(event.type): + self._wake_info = Info.from_event(event) + self._wake_info_ready.set() + return + if self.is_streaming or (self.server_id is None): - # Not streaming or no server connected + # Not detecting or no server connected return if Detection.is_type(event.type): @@ -1281,6 +1298,9 @@ async def event_from_wake(self, event: Event) -> None: # No refractory period self.refractory_timestamp.pop(detection.name, None) + # Forward to the server + await self.event_to_server(event) + # Match detected wake word name with pipeline name pipeline_name: Optional[str] = None if self.settings.wake.names: @@ -1294,3 +1314,19 @@ async def event_from_wake(self, event: Event) -> None: await self.forward_event(event) # forward to event service await self.trigger_detection(Detection.from_event(event)) await self.trigger_streaming_start() + + async def update_info(self, info: Info) -> None: + self._wake_info = None + self._wake_info_ready.clear() + await self.event_to_wake(Describe().event()) + + try: + await asyncio.wait_for( + self._wake_info_ready.wait(), timeout=_WAKE_INFO_TIMEOUT + ) + + if self._wake_info is not None: + # Update wake info only + info.wake = self._wake_info.wake + except asyncio.TimeoutError: + _LOGGER.warning("Failed to get info from wake service")