From 36353ff03ab96160f77f0b1901a516bad4598133 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 19 Feb 2024 14:20:57 -0600 Subject: [PATCH] Flesh out test --- tests/shared.py | 60 +++++++++++++- tests/test_satellite.py | 146 ++++++++++++++++++++--------------- tests/test_wake_streaming.py | 13 +--- 3 files changed, 145 insertions(+), 74 deletions(-) diff --git a/tests/shared.py b/tests/shared.py index 8d6e01d..2b47c37 100644 --- a/tests/shared.py +++ b/tests/shared.py @@ -1,6 +1,64 @@ """Shared code for Wyoming satellite tests.""" -from wyoming.audio import AudioChunk +import asyncio +import io +from collections.abc import Iterable +from typing import Optional + +from wyoming.audio import AudioChunk, AudioStart, AudioStop +from wyoming.client import AsyncClient +from wyoming.event import Event + +AUDIO_START = AudioStart(rate=16000, width=2, channels=1) +AUDIO_STOP = AudioStop() AUDIO_CHUNK = AudioChunk( rate=16000, width=2, channels=1, audio=bytes([255] * 960) # 30ms ) + + +class FakeStreamReaderWriter: + def __init__(self) -> None: + self._undrained_data = bytes() + self._value = bytes() + self._data_ready = asyncio.Event() + + def write(self, data: bytes) -> None: + self._undrained_data += data + + def writelines(self, data: Iterable[bytes]) -> None: + for line in data: + self.write(line) + + async def drain(self) -> None: + self._value += self._undrained_data + self._undrained_data = bytes() + self._data_ready.set() + self._data_ready.clear() + + async def readline(self) -> bytes: + while b"\n" not in self._value: + await self._data_ready.wait() + + with io.BytesIO(self._value) as value_io: + data = value_io.readline() + self._value = self._value[len(data) :] + return data + + async def readexactly(self, n: int) -> bytes: + while len(self._value) < n: + await self._data_ready.wait() + + data = self._value[:n] + self._value = self._value[n:] + return data + + +class MicClient(AsyncClient): + async def read_event(self) -> Optional[Event]: + # Send 30ms of audio every 30ms + await asyncio.sleep(AUDIO_CHUNK.seconds) + return AUDIO_CHUNK.event() + + async def write_event(self, event: Event) -> None: + # Output only + pass diff --git a/tests/test_satellite.py b/tests/test_satellite.py index 528dd5f..a6686ab 100644 --- a/tests/test_satellite.py +++ b/tests/test_satellite.py @@ -1,48 +1,40 @@ import asyncio -import io import logging -from collections.abc import Iterable -from pathlib import Path from typing import Final, Optional from unittest.mock import patch import pytest from wyoming.asr import Transcript -from wyoming.audio import AudioChunk +from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.client import AsyncClient from wyoming.event import Event, async_read_event from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite, StreamingStarted, StreamingStopped +from wyoming.tts import Synthesize from wyoming.wake import Detection from wyoming_satellite import ( EventSettings, MicSettings, SatelliteSettings, + SndSettings, WakeSettings, WakeStreamingSatellite, ) -from .shared import AUDIO_CHUNK +from .shared import ( + AUDIO_CHUNK, + AUDIO_START, + AUDIO_STOP, + FakeStreamReaderWriter, + MicClient, +) _LOGGER = logging.getLogger() TIMEOUT: Final = 1 -class MicClient(AsyncClient): - def __init__(self) -> None: - super().__init__() - - async def read_event(self) -> Optional[Event]: - await asyncio.sleep(AUDIO_CHUNK.seconds) - return AUDIO_CHUNK.event() - - async def write_event(self, event: Event) -> None: - # Output only - pass - - class WakeClient(AsyncClient): def __init__(self) -> None: super().__init__() @@ -63,12 +55,40 @@ async def write_event(self, event: Event) -> None: self._event_ready.set() +class SndClient(AsyncClient): + def __init__(self) -> None: + super().__init__() + self.synthesize = asyncio.Event() + self.audio_start = asyncio.Event() + self.audio_chunk = asyncio.Event() + self.audio_stop = asyncio.Event() + + async def read_event(self) -> Optional[Event]: + # Input only + pass + + async def write_event(self, event: Event) -> None: + if AudioChunk.is_type(event.type): + self.audio_chunk.set() + elif Synthesize.is_type(event.type): + self.synthesize.set() + elif AudioStart.is_type(event.type): + self.audio_start.set() + elif AudioStop.is_type(event.type): + self.audio_stop.set() + + class EventClient(AsyncClient): def __init__(self) -> None: super().__init__() self.detection = asyncio.Event() self.streaming_started = asyncio.Event() self.streaming_stopped = asyncio.Event() + self.transcript = asyncio.Event() + self.synthesize = asyncio.Event() + self.audio_start = asyncio.Event() + self.audio_chunk = asyncio.Event() + self.audio_stop = asyncio.Event() async def read_event(self) -> Optional[Event]: # Input only @@ -81,54 +101,34 @@ async def write_event(self, event: Event) -> None: self.streaming_started.set() elif StreamingStopped.is_type(event.type): self.streaming_stopped.set() + elif Transcript.is_type(event.type): + self.transcript.set() + elif Synthesize.is_type(event.type): + self.synthesize.set() + elif AudioChunk.is_type(event.type): + self.audio_chunk.set() + elif AudioStart.is_type(event.type): + self.audio_start.set() + elif AudioStop.is_type(event.type): + self.audio_stop.set() -class FakeStreamReaderWriter: - def __init__(self) -> None: - self._undrained_data = bytes() - self._value = bytes() - self._data_ready = asyncio.Event() - - def write(self, data: bytes) -> None: - self._undrained_data += data - - def writelines(self, data: Iterable[bytes]) -> None: - for line in data: - self.write(line) - - async def drain(self) -> None: - self._value += self._undrained_data - self._undrained_data = bytes() - self._data_ready.set() - self._data_ready.clear() - - async def readline(self) -> bytes: - while b"\n" not in self._value: - await self._data_ready.wait() - - with io.BytesIO(self._value) as value_io: - data = value_io.readline() - self._value = self._value[len(data) :] - return data - - async def readexactly(self, n: int) -> bytes: - while len(self._value) < n: - await self._data_ready.wait() - - data = self._value[:n] - self._value = self._value[n:] - return data +# ----------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_satellite_and_server(tmp_path: Path) -> None: +async def test_wake_satellite() -> None: mic_client = MicClient() + snd_client = SndClient() wake_client = WakeClient() event_client = EventClient() with patch( "wyoming_satellite.satellite.SatelliteBase._make_mic_client", return_value=mic_client, + ), patch( + "wyoming_satellite.satellite.SatelliteBase._make_snd_client", + return_value=snd_client, ), patch( "wyoming_satellite.satellite.SatelliteBase._make_wake_client", return_value=wake_client, @@ -139,19 +139,22 @@ async def test_satellite_and_server(tmp_path: Path) -> None: satellite = WakeStreamingSatellite( SatelliteSettings( mic=MicSettings(uri="test"), + snd=SndSettings(uri="test"), wake=WakeSettings(uri="test"), event=EventSettings(uri="test"), ) ) - # Fake server connection - server_io = FakeStreamReaderWriter() - await satellite.set_server("test", server_io) # type: ignore - async def event_from_satellite() -> Optional[Event]: return await async_read_event(server_io) satellite_task = asyncio.create_task(satellite.run(), name="satellite") + + # Fake server connection + server_io = FakeStreamReaderWriter() + await satellite.set_server("test", server_io) # type: ignore + + # Start satellite await satellite.event_from_server(RunSatellite().event()) # Trigger detection @@ -165,9 +168,7 @@ async def event_from_satellite() -> Optional[Event]: assert RunPipeline.is_type(event.type), event run_pipeline = RunPipeline.from_event(event) assert run_pipeline.start_stage == PipelineStage.ASR - - # No TTS - assert run_pipeline.end_stage == PipelineStage.HANDLE + assert run_pipeline.end_stage == PipelineStage.TTS # Event service should have received detection await asyncio.wait_for(event_client.detection.wait(), timeout=TIMEOUT) @@ -185,6 +186,9 @@ async def event_from_satellite() -> Optional[Event]: # Send transcript await satellite.event_from_server(Transcript(text="test").event()) + # Event service should have received transcript + await asyncio.wait_for(event_client.transcript.wait(), timeout=TIMEOUT) + # Wait for streaming to stop while satellite.is_streaming: event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) @@ -194,5 +198,25 @@ async def event_from_satellite() -> Optional[Event]: # Event service should have received streaming stop await asyncio.wait_for(event_client.streaming_stopped.wait(), timeout=TIMEOUT) + # Fake a TTS response + await satellite.event_from_server(Synthesize(text="test").event()) + + # Event service should have received synthesize + await asyncio.wait_for(event_client.synthesize.wait(), timeout=TIMEOUT) + + # Audio start, chunk, stop + await satellite.event_from_server(AUDIO_START.event()) + await asyncio.wait_for(snd_client.audio_start.wait(), timeout=TIMEOUT) + await asyncio.wait_for(event_client.audio_start.wait(), timeout=TIMEOUT) + + # Event service does not get audio chunks, just start/stop + await satellite.event_from_server(AUDIO_CHUNK.event()) + await asyncio.wait_for(snd_client.audio_chunk.wait(), timeout=TIMEOUT) + + await satellite.event_from_server(AUDIO_STOP.event()) + await asyncio.wait_for(snd_client.audio_stop.wait(), timeout=TIMEOUT) + await asyncio.wait_for(event_client.audio_stop.wait(), timeout=TIMEOUT) + + # Stop satellite await satellite.stop() await satellite_task diff --git a/tests/test_wake_streaming.py b/tests/test_wake_streaming.py index 0977509..2550ee5 100644 --- a/tests/test_wake_streaming.py +++ b/tests/test_wake_streaming.py @@ -20,22 +20,11 @@ WakeStreamingSatellite, ) -from .shared import AUDIO_CHUNK +from .shared import MicClient _LOGGER = logging.getLogger() -class MicClient(AsyncClient): - async def read_event(self) -> Optional[Event]: - # Send 30ms of audio every 30ms - await asyncio.sleep(AUDIO_CHUNK.seconds) - return AUDIO_CHUNK.event() - - async def write_event(self, event: Event) -> None: - # Output only - pass - - class WakeClient(AsyncClient): def __init__(self) -> None: super().__init__()