diff --git a/examples/2mic_service.py b/examples/2mic_service.py index ac58c4a..f6eba1e 100644 --- a/examples/2mic_service.py +++ b/examples/2mic_service.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Controls the LEDs on the ReSpeaker 2mic HAT.""" +"""Controls the LEDs and GPIO Button on the ReSpeaker 2mic HAT.""" import argparse import asyncio import logging @@ -10,6 +10,8 @@ import gpiozero import spidev +import RPi.GPIO as GPIO + from wyoming.asr import Transcript from wyoming.event import Event from wyoming.satellite import ( @@ -21,12 +23,13 @@ ) from wyoming.server import AsyncEventHandler, AsyncServer from wyoming.vad import VoiceStarted -from wyoming.wake import Detection +from wyoming.wake import Detect, Detection _LOGGER = logging.getLogger() NUM_LEDS = 3 LEDS_GPIO = 12 +BUTTON_GPIO = 17 RGB_MAP = { "rgb": [3, 2, 1], "rbg": [3, 1, 2], @@ -41,14 +44,14 @@ async def main() -> None: """Main entry point.""" parser = argparse.ArgumentParser() parser.add_argument("--uri", required=True, help="unix:// or tcp://") - # parser.add_argument("--debug", action="store_true", help="Log DEBUG messages") + parser.add_argument("--log-format", default=logging.BASIC_FORMAT, help="Format for log messages") args = parser.parse_args() - logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO) - _LOGGER.debug(args) + logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO, format=args.log_format) - _LOGGER.info("Ready") + _LOGGER.debug(args) + _LOGGER.info("Event service Ready") # Turn on power to LEDs led_power = gpiozero.LED(LEDS_GPIO, active_high=False) @@ -56,11 +59,15 @@ async def main() -> None: leds = APA102(num_led=NUM_LEDS) + # GPIO Button + GPIO.setmode(GPIO.BCM) + GPIO.setup(BUTTON_GPIO, GPIO.IN) + # Start server server = AsyncServer.from_uri(args.uri) try: - await server.run(partial(LEDsEventHandler, args, leds)) + await server.run(partial(EventHandler, args, leds)) except KeyboardInterrupt: pass finally: @@ -78,7 +85,7 @@ async def main() -> None: _GREEN = (0, 255, 0) -class LEDsEventHandler(AsyncEventHandler): +class EventHandler(AsyncEventHandler): """Event handler for clients.""" def __init__( @@ -93,13 +100,25 @@ def __init__( self.cli_args = cli_args self.client_id = str(time.monotonic_ns()) self.leds = leds + self.detect_name = None + + GPIO.add_event_detect(BUTTON_GPIO, GPIO.RISING, callback=self.button_callback) _LOGGER.debug("Client connected: %s", self.client_id) + + def button_callback(self, button_pin): + _LOGGER.debug("Button pressed #%s", button_pin) + asyncio.run(self.write_event(Detection(name=self.detect_name, timestamp=time.monotonic_ns()).event())) + + async def handle_event(self, event: Event) -> bool: _LOGGER.debug(event) - if StreamingStarted.is_type(event.type): + if Detect.is_type(event.type): + detect = Detect.from_event(event) + self.detect_name = detect.names[0] + elif StreamingStarted.is_type(event.type): self.color(_YELLOW) elif Detection.is_type(event.type): self.color(_BLUE) diff --git a/tests/test_wake_streaming.py b/tests/test_wake_streaming.py index ec98ab9..a906f4f 100644 --- a/tests/test_wake_streaming.py +++ b/tests/test_wake_streaming.py @@ -56,8 +56,8 @@ def __init__(self) -> None: self.wake_event = asyncio.Event() async def read_event(self) -> Optional[Event]: - # Input only - return None + # Sends a detection event + return Detection().event() async def write_event(self, event: Event) -> None: if Detection.is_type(event.type): diff --git a/wyoming_satellite/satellite.py b/wyoming_satellite/satellite.py index 84b5115..3322778 100644 --- a/wyoming_satellite/satellite.py +++ b/wyoming_satellite/satellite.py @@ -261,7 +261,7 @@ async def event_from_server(self, event: Event) -> None: await self.trigger_detect() elif Detection.is_type(event.type): # Wake word detected - _LOGGER.debug("Wake word detected") + _LOGGER.debug("Remote wake word detected") await self.trigger_detection(Detection.from_event(event)) elif VoiceStarted.is_type(event.type): # STT start @@ -350,8 +350,6 @@ async def _connect_to_services(self) -> None: self._event_task_proc(), name="event" ) - _LOGGER.info("Connected to services") - async def _disconnect_from_services(self) -> None: """Disconnects from running services.""" if self._mic_task is not None: @@ -720,6 +718,7 @@ async def _disconnect() -> None: await asyncio.sleep(self.settings.wake.reconnect_seconds) continue + _LOGGER.debug("Event received from wake service") await self.event_from_wake(event) except asyncio.CancelledError: @@ -832,22 +831,69 @@ async def _disconnect() -> None: if self._event_queue is None: self._event_queue = asyncio.Queue() - event = await self._event_queue.get() - if event_client is None: event_client = self._make_event_client() assert event_client is not None await event_client.connect() _LOGGER.debug("Connected to event service") - await event_client.write_event(event) + # Reset + from_client_task = None + to_client_task = None + pending = set() + self._event_queue = asyncio.Queue() + + # Inform event service of the wake word handled by this satellite instance + await self.forward_event(Detect(names=self.settings.wake.names).event()) + + # Read/write in "parallel" + if to_client_task is None: + # From satellite to event service + to_client_task = asyncio.create_task( + self._event_queue.get(), name="event_to_client" + ) + pending.add(to_client_task) + + if from_client_task is None: + # From event service to satellite + from_client_task = asyncio.create_task( + event_client.read_event(), name="event_from_client" + ) + pending.add(from_client_task) + + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) + + if to_client_task in done: + # Forward event to event service for handling + assert to_client_task is not None + event = to_client_task.result() + to_client_task = None + await event_client.write_event(event) + + if from_client_task in done: + # Event from event service (button for detection) + assert from_client_task is not None + event = from_client_task.result() + from_client_task = None + + if event is None: + _LOGGER.warning("Event service disconnected") + await _disconnect() + event_client = None # reconnect + await asyncio.sleep(self.settings.event.reconnect_seconds) + continue + + _LOGGER.debug("Event received from event service") + if Detection.is_type(event.type): + await self.event_from_wake(event) except asyncio.CancelledError: break except Exception: _LOGGER.exception("Unexpected error in event read task") await _disconnect() event_client = None # reconnect - self._event_queue = None await asyncio.sleep(self.settings.event.reconnect_seconds) await _disconnect() @@ -861,6 +907,7 @@ class AlwaysStreamingSatellite(SatelliteBase): def __init__(self, settings: SatelliteSettings) -> None: super().__init__(settings) + _LOGGER.debug("Initiating an AlwaysStreamingSatellite") self.is_streaming = False if settings.vad.enabled: @@ -927,6 +974,7 @@ def __init__(self, settings: SatelliteSettings) -> None: raise ValueError("VAD is not enabled") super().__init__(settings) + _LOGGER.debug("Initiating a VadStreamingSatellite") self.is_streaming = False self.vad = SileroVad( threshold=settings.vad.threshold, trigger_level=settings.vad.trigger_level @@ -1088,6 +1136,7 @@ def __init__(self, settings: SatelliteSettings) -> None: raise ValueError("Local wake word detection is not enabled") super().__init__(settings) + _LOGGER.debug("Initiating a WakeStreamingSatellite") self.is_streaming = False # Timestamp in the future when the refractory period is over (set with @@ -1202,6 +1251,7 @@ async def event_from_wake(self, event: Event) -> None: return if Detection.is_type(event.type): + _LOGGER.debug("Detection triggered from event") detection = Detection.from_event(event) # Check refractory period to avoid multiple back-to-back detections