Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Push to talk #82

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions examples/2mic_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
"""Controls the LEDs on the ReSpeaker 2mic HAT."""
"""Controls the LEDs and GPIO Button on the ReSpeaker 2mic HAT."""
import argparse
import asyncio
import logging
Expand All @@ -10,6 +10,8 @@

import gpiozero
import spidev
import RPi.GPIO as GPIO

from wyoming.asr import Transcript
from wyoming.event import Event
from wyoming.satellite import (
Expand All @@ -21,12 +23,13 @@
)
from wyoming.server import AsyncEventHandler, AsyncServer
from wyoming.vad import VoiceStarted
from wyoming.wake import Detection
from wyoming.wake import Detect, Detection

_LOGGER = logging.getLogger()

NUM_LEDS = 3
LEDS_GPIO = 12
BUTTON_GPIO = 17
RGB_MAP = {
"rgb": [3, 2, 1],
"rbg": [3, 1, 2],
Expand All @@ -41,26 +44,30 @@ async def main() -> None:
"""Main entry point."""
parser = argparse.ArgumentParser()
parser.add_argument("--uri", required=True, help="unix:// or tcp://")
#
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
parser.add_argument("--log-format", default=logging.BASIC_FORMAT, help="Format for log messages")
args = parser.parse_args()

logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
_LOGGER.debug(args)
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO, format=args.log_format)

_LOGGER.info("Ready")
_LOGGER.debug(args)
_LOGGER.info("Event service Ready")

# Turn on power to LEDs
led_power = gpiozero.LED(LEDS_GPIO, active_high=False)
led_power.on()

leds = APA102(num_led=NUM_LEDS)

# GPIO Button
GPIO.setmode(GPIO.BCM)
GPIO.setup(BUTTON_GPIO, GPIO.IN)

# Start server
server = AsyncServer.from_uri(args.uri)

try:
await server.run(partial(LEDsEventHandler, args, leds))
await server.run(partial(EventHandler, args, leds))
except KeyboardInterrupt:
pass
finally:
Expand All @@ -78,7 +85,7 @@ async def main() -> None:
_GREEN = (0, 255, 0)


class LEDsEventHandler(AsyncEventHandler):
class EventHandler(AsyncEventHandler):
"""Event handler for clients."""

def __init__(
Expand All @@ -93,13 +100,25 @@ def __init__(
self.cli_args = cli_args
self.client_id = str(time.monotonic_ns())
self.leds = leds
self.detect_name = None

GPIO.add_event_detect(BUTTON_GPIO, GPIO.RISING, callback=self.button_callback)

_LOGGER.debug("Client connected: %s", self.client_id)


def button_callback(self, button_pin):
_LOGGER.debug("Button pressed #%s", button_pin)
asyncio.run(self.write_event(Detection(name=self.detect_name, timestamp=time.monotonic_ns()).event()))


async def handle_event(self, event: Event) -> bool:
_LOGGER.debug(event)

if StreamingStarted.is_type(event.type):
if Detect.is_type(event.type):
detect = Detect.from_event(event)
self.detect_name = detect.names[0]
elif StreamingStarted.is_type(event.type):
self.color(_YELLOW)
elif Detection.is_type(event.type):
self.color(_BLUE)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_wake_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(self) -> None:
self.wake_event = asyncio.Event()

async def read_event(self) -> Optional[Event]:
# Input only
return None
# Sends a detection event
return Detection().event()

async def write_event(self, event: Event) -> None:
if Detection.is_type(event.type):
Expand Down
64 changes: 57 additions & 7 deletions wyoming_satellite/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ async def event_from_server(self, event: Event) -> None:
await self.trigger_detect()
elif Detection.is_type(event.type):
# Wake word detected
_LOGGER.debug("Wake word detected")
_LOGGER.debug("Remote wake word detected")
await self.trigger_detection(Detection.from_event(event))
elif VoiceStarted.is_type(event.type):
# STT start
Expand Down Expand Up @@ -350,8 +350,6 @@ async def _connect_to_services(self) -> None:
self._event_task_proc(), name="event"
)

_LOGGER.info("Connected to services")

async def _disconnect_from_services(self) -> None:
"""Disconnects from running services."""
if self._mic_task is not None:
Expand Down Expand Up @@ -720,6 +718,7 @@ async def _disconnect() -> None:
await asyncio.sleep(self.settings.wake.reconnect_seconds)
continue

_LOGGER.debug("Event received from wake service")
await self.event_from_wake(event)

except asyncio.CancelledError:
Expand Down Expand Up @@ -832,22 +831,69 @@ async def _disconnect() -> None:
if self._event_queue is None:
self._event_queue = asyncio.Queue()

event = await self._event_queue.get()

if event_client is None:
event_client = self._make_event_client()
assert event_client is not None
await event_client.connect()
_LOGGER.debug("Connected to event service")

await event_client.write_event(event)
# Reset
from_client_task = None
to_client_task = None
pending = set()
self._event_queue = asyncio.Queue()

# Inform event service of the wake word handled by this satellite instance
await self.forward_event(Detect(names=self.settings.wake.names).event())

# Read/write in "parallel"
if to_client_task is None:
# From satellite to event service
to_client_task = asyncio.create_task(
self._event_queue.get(), name="event_to_client"
)
pending.add(to_client_task)

if from_client_task is None:
# From event service to satellite
from_client_task = asyncio.create_task(
event_client.read_event(), name="event_from_client"
)
pending.add(from_client_task)

done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)

if to_client_task in done:
# Forward event to event service for handling
assert to_client_task is not None
event = to_client_task.result()
to_client_task = None
await event_client.write_event(event)

if from_client_task in done:
# Event from event service (button for detection)
assert from_client_task is not None
event = from_client_task.result()
from_client_task = None

if event is None:
_LOGGER.warning("Event service disconnected")
await _disconnect()
event_client = None # reconnect
await asyncio.sleep(self.settings.event.reconnect_seconds)
continue

_LOGGER.debug("Event received from event service")
if Detection.is_type(event.type):
await self.event_from_wake(event)
except asyncio.CancelledError:
break
except Exception:
_LOGGER.exception("Unexpected error in event read task")
await _disconnect()
event_client = None # reconnect
self._event_queue = None
await asyncio.sleep(self.settings.event.reconnect_seconds)

await _disconnect()
Expand All @@ -861,6 +907,7 @@ class AlwaysStreamingSatellite(SatelliteBase):

def __init__(self, settings: SatelliteSettings) -> None:
super().__init__(settings)
_LOGGER.debug("Initiating an AlwaysStreamingSatellite")
self.is_streaming = False

if settings.vad.enabled:
Expand Down Expand Up @@ -927,6 +974,7 @@ def __init__(self, settings: SatelliteSettings) -> None:
raise ValueError("VAD is not enabled")

super().__init__(settings)
_LOGGER.debug("Initiating a VadStreamingSatellite")
self.is_streaming = False
self.vad = SileroVad(
threshold=settings.vad.threshold, trigger_level=settings.vad.trigger_level
Expand Down Expand Up @@ -1088,6 +1136,7 @@ def __init__(self, settings: SatelliteSettings) -> None:
raise ValueError("Local wake word detection is not enabled")

super().__init__(settings)
_LOGGER.debug("Initiating a WakeStreamingSatellite")
self.is_streaming = False

# Timestamp in the future when the refractory period is over (set with
Expand Down Expand Up @@ -1202,6 +1251,7 @@ async def event_from_wake(self, event: Event) -> None:
return

if Detection.is_type(event.type):
_LOGGER.debug("Detection triggered from event")
detection = Detection.from_event(event)

# Check refractory period to avoid multiple back-to-back detections
Expand Down