Skip to content

Commit

Permalink
Add VoIP announce (home-assistant#136781)
Browse files Browse the repository at this point in the history
* Implement async_announce for VoIP

* Add tests

* Add network to voip dependencies
  • Loading branch information
synesthesiam authored and zxdavb committed Jan 29, 2025
1 parent 7348f1b commit 0299cfe
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 14 deletions.
140 changes: 128 additions & 12 deletions homeassistant/components/voip/assist_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,29 @@
import io
import logging
from pathlib import Path
import socket
import time
from typing import TYPE_CHECKING, Any, Final
import wave

from voip_utils import RtpDatagramProtocol
from voip_utils import SIP_PORT, RtpDatagramProtocol
from voip_utils.sip import SipEndpoint, get_sip_endpoint

from homeassistant.components import tts
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
AssistSatelliteEntityFeature,
)
from homeassistant.components.network import async_get_source_ip
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .const import CHANNELS, CONF_SIP_PORT, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .devices import VoIPDevice
from .entity import VoIPEntity

Expand All @@ -34,6 +40,9 @@
_LOGGER = logging.getLogger(__name__)

_PIPELINE_TIMEOUT_SEC: Final = 30
_ANNOUNCEMENT_BEFORE_DELAY: Final = 0.5
_ANNOUNCEMENT_AFTER_DELAY: Final = 1.0
_ANNOUNCEMENT_HANGUP_SEC: Final = 0.5


class Tones(IntFlag):
Expand Down Expand Up @@ -80,6 +89,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
_attr_translation_key = "assist_satellite"
_attr_name = None
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE

def __init__(
self,
Expand All @@ -105,6 +115,12 @@ def __init__(
self._tones = tones
self._processing_tone_done = asyncio.Event()

self._announcement: AssistSatelliteAnnouncement | None = None
self._announcement_done = asyncio.Event()
self._check_announcement_ended_task: asyncio.Task | None = None
self._last_chunk_time: float | None = None
self._rtp_port: int | None = None

@property
def pipeline_entity_id(self) -> str | None:
"""Return the entity ID of the pipeline to use for the next conversation."""
Expand Down Expand Up @@ -149,25 +165,108 @@ async def async_set_configuration(
"""Set the current satellite configuration."""
raise NotImplementedError

async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on the satellite.
Plays announcement in a loop, blocking until the caller hangs up.
"""
self._announcement_done.clear()

if self._rtp_port is None:
# Choose random port for RTP
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", 0))
_rtp_ip, self._rtp_port = sock.getsockname()
sock.close()

# HA SIP server
source_ip = await async_get_source_ip(self.hass)
sip_port = self.config_entry.options.get(CONF_SIP_PORT, SIP_PORT)
source_endpoint = get_sip_endpoint(host=source_ip, port=sip_port)

try:
# VoIP ID is SIP header
destination_endpoint = SipEndpoint(self.voip_device.voip_id)
except ValueError:
# VoIP ID is IP address
destination_endpoint = get_sip_endpoint(
host=self.voip_device.voip_id, port=SIP_PORT
)

self._announcement = announcement

# Make the call
self.hass.data[DOMAIN].protocol.outgoing_call(
source=source_endpoint,
destination=destination_endpoint,
rtp_port=self._rtp_port,
)

await self._announcement_done.wait()

async def _check_announcement_ended(self) -> None:
"""Continuously checks if an audio chunk was received within a time limit.
If not, the caller is presumed to have hung up and the announcement is ended.
"""
while self._announcement is not None:
if (self._last_chunk_time is not None) and (
(time.monotonic() - self._last_chunk_time) > _ANNOUNCEMENT_HANGUP_SEC
):
# Caller hung up
self._announcement = None
self._announcement_done.set()
self._check_announcement_ended_task = None
_LOGGER.debug("Announcement ended")
break

await asyncio.sleep(_ANNOUNCEMENT_HANGUP_SEC / 2)

# -------------------------------------------------------------------------
# VoIP
# -------------------------------------------------------------------------

def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk."""
if self._run_pipeline_task is None:
# Run pipeline until voice command finishes, then start over
self._clear_audio_queue()
self._tts_done.clear()
self._last_chunk_time = time.monotonic()

if self._announcement is None:
# Pipeline with STT
if self._run_pipeline_task is None:
# Run pipeline until voice command finishes, then start over
self._clear_audio_queue()
self._tts_done.clear()
self._run_pipeline_task = (
self.config_entry.async_create_background_task(
self.hass,
self._run_pipeline(),
"voip_pipeline_run",
)
)

self._audio_queue.put_nowait(audio_bytes)
elif self._run_pipeline_task is None:
# Announcement only
if self._check_announcement_ended_task is None:
# Check if caller hung up
self._check_announcement_ended_task = (
self.config_entry.async_create_background_task(
self.hass,
self._check_announcement_ended(),
"voip_announcement_ended",
)
)

# Play announcement (will repeat)
self._run_pipeline_task = self.config_entry.async_create_background_task(
self.hass,
self._run_pipeline(),
"voip_pipeline_run",
self._play_announcement(self._announcement),
"voip_play_announcement",
)

self._audio_queue.put_nowait(audio_bytes)

async def _run_pipeline(self) -> None:
"""Run a pipeline with STT input and TTS output."""
_LOGGER.debug("Starting pipeline")

self.async_set_context(Context(user_id=self.config_entry.data["user"]))
Expand Down Expand Up @@ -209,6 +308,23 @@ async def stt_stream():
self._run_pipeline_task = None
_LOGGER.debug("Pipeline finished")

async def _play_announcement(
self, announcement: AssistSatelliteAnnouncement
) -> None:
"""Play an announcement once."""
_LOGGER.debug("Playing announcement")

try:
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
except Exception:
_LOGGER.exception("Unexpected error while playing announcement")
raise
finally:
self._run_pipeline_task = None
_LOGGER.debug("Announcement finished")

def _clear_audio_queue(self) -> None:
"""Ensure audio queue is empty."""
while not self._audio_queue.empty():
Expand Down Expand Up @@ -239,7 +355,7 @@ def on_pipeline_event(self, event: PipelineEvent) -> None:
self._pipeline_had_error = True
_LOGGER.warning(event)

async def _send_tts(self, media_id: str) -> None:
async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None:
"""Send TTS audio to caller via RTP."""
try:
if self.transport is None:
Expand All @@ -253,7 +369,7 @@ async def _send_tts(self, media_id: str) -> None:
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")

if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
# Don't overlap TTS and processing beep
_LOGGER.debug("Waiting for processing tone")
await self._processing_tone_done.wait()
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/voip/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"name": "Voice over IP",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["assist_pipeline", "assist_satellite"],
"dependencies": ["assist_pipeline", "assist_satellite", "network"],
"documentation": "https://www.home-assistant.io/integrations/voip",
"iot_class": "local_push",
"quality_scale": "internal",
Expand Down
99 changes: 98 additions & 1 deletion tests/components/voip/test_voip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# pylint: disable-next=hass-component-root-import
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.voip import HassVoipDatagramProtocol
from homeassistant.components.voip import DOMAIN, HassVoipDatagramProtocol
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
Expand Down Expand Up @@ -844,3 +844,100 @@ async def async_send_audio(audio_bytes: bytes, **kwargs):

assert sum(played_audio_bytes) > 0
assert played_audio_bytes == snapshot()


@pytest.mark.usefixtures("socket_enabled")
async def test_announce(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test announcement."""
assert await async_setup_component(hass, "voip", {})

satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
assert (
satellite.supported_features
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
)

announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
original_media_id=_MEDIA_ID,
media_id_source="tts",
)

# Protocol has already been mocked, but "outgoing_call" is not async
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
mock_protocol.outgoing_call = Mock()

with (
patch(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts,
):
satellite.transport = Mock()
announce_task = hass.async_create_background_task(
satellite.async_announce(announcement), "voip_announce"
)
await asyncio.sleep(0)
mock_protocol.outgoing_call.assert_called_once()

# Trigger announcement
satellite.on_chunk(bytes(_ONE_SECOND))
await announce_task

mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)


@pytest.mark.usefixtures("socket_enabled")
async def test_voip_id_is_ip_address(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test announcement when VoIP is an IP address instead of a SIP header."""
assert await async_setup_component(hass, "voip", {})

satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
assert (
satellite.supported_features
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
)

announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
original_media_id=_MEDIA_ID,
media_id_source="tts",
)

# Protocol has already been mocked, but "outgoing_call" is not async
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
mock_protocol.outgoing_call = Mock()

with (
patch.object(voip_device, "voip_id", "192.168.68.10"),
patch(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts,
):
satellite.transport = Mock()
announce_task = hass.async_create_background_task(
satellite.async_announce(announcement), "voip_announce"
)
await asyncio.sleep(0)
mock_protocol.outgoing_call.assert_called_once()
assert (
mock_protocol.outgoing_call.call_args.kwargs["destination"].host
== "192.168.68.10"
)

# Trigger announcement
satellite.on_chunk(bytes(_ONE_SECOND))
await announce_task

mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)

0 comments on commit 0299cfe

Please sign in to comment.