Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VoIP announce #136781

Merged
merged 3 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 128 additions & 12 deletions homeassistant/components/voip/assist_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,29 @@
import io
import logging
from pathlib import Path
import socket
import time
from typing import TYPE_CHECKING, Any, Final
import wave

from voip_utils import RtpDatagramProtocol
from voip_utils import SIP_PORT, RtpDatagramProtocol
from voip_utils.sip import SipEndpoint, get_sip_endpoint

from homeassistant.components import tts
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
AssistSatelliteEntityFeature,
)
from homeassistant.components.network import async_get_source_ip
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .const import CHANNELS, CONF_SIP_PORT, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
from .devices import VoIPDevice
from .entity import VoIPEntity

Expand All @@ -34,6 +40,9 @@
_LOGGER = logging.getLogger(__name__)

_PIPELINE_TIMEOUT_SEC: Final = 30
_ANNOUNCEMENT_BEFORE_DELAY: Final = 0.5
_ANNOUNCEMENT_AFTER_DELAY: Final = 1.0
_ANNOUNCEMENT_HANGUP_SEC: Final = 0.5


class Tones(IntFlag):
Expand Down Expand Up @@ -80,6 +89,7 @@
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
_attr_translation_key = "assist_satellite"
_attr_name = None
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE

def __init__(
self,
Expand All @@ -105,6 +115,12 @@
self._tones = tones
self._processing_tone_done = asyncio.Event()

self._announcement: AssistSatelliteAnnouncement | None = None
self._announcement_done = asyncio.Event()
self._check_announcement_ended_task: asyncio.Task | None = None
self._last_chunk_time: float | None = None
self._rtp_port: int | None = None

@property
def pipeline_entity_id(self) -> str | None:
"""Return the entity ID of the pipeline to use for the next conversation."""
Expand Down Expand Up @@ -149,25 +165,108 @@
"""Set the current satellite configuration."""
raise NotImplementedError

async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on the satellite.

Plays announcement in a loop, blocking until the caller hangs up.
"""
self._announcement_done.clear()

if self._rtp_port is None:
# Choose random port for RTP
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", 0))
_rtp_ip, self._rtp_port = sock.getsockname()
sock.close()

# HA SIP server
source_ip = await async_get_source_ip(self.hass)
sip_port = self.config_entry.options.get(CONF_SIP_PORT, SIP_PORT)
source_endpoint = get_sip_endpoint(host=source_ip, port=sip_port)

try:
# VoIP ID is SIP header
destination_endpoint = SipEndpoint(self.voip_device.voip_id)
except ValueError:
# VoIP ID is IP address
destination_endpoint = get_sip_endpoint(
host=self.voip_device.voip_id, port=SIP_PORT
)

self._announcement = announcement

# Make the call
self.hass.data[DOMAIN].protocol.outgoing_call(
source=source_endpoint,
destination=destination_endpoint,
rtp_port=self._rtp_port,
)

await self._announcement_done.wait()

async def _check_announcement_ended(self) -> None:
"""Continuously checks if an audio chunk was received within a time limit.

If not, the caller is presumed to have hung up and the announcement is ended.
"""
while self._announcement is not None:
if (self._last_chunk_time is not None) and (
(time.monotonic() - self._last_chunk_time) > _ANNOUNCEMENT_HANGUP_SEC
):
# Caller hung up
self._announcement = None
self._announcement_done.set()
self._check_announcement_ended_task = None
_LOGGER.debug("Announcement ended")
break

await asyncio.sleep(_ANNOUNCEMENT_HANGUP_SEC / 2)

# -------------------------------------------------------------------------
# VoIP
# -------------------------------------------------------------------------

def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk."""
if self._run_pipeline_task is None:
# Run pipeline until voice command finishes, then start over
self._clear_audio_queue()
self._tts_done.clear()
self._last_chunk_time = time.monotonic()

if self._announcement is None:
# Pipeline with STT
if self._run_pipeline_task is None:
# Run pipeline until voice command finishes, then start over
self._clear_audio_queue()
self._tts_done.clear()
self._run_pipeline_task = (
self.config_entry.async_create_background_task(
self.hass,
self._run_pipeline(),
"voip_pipeline_run",
)
)

Check warning on line 246 in homeassistant/components/voip/assist_satellite.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/voip/assist_satellite.py#L246

Added line #L246 was not covered by tests

self._audio_queue.put_nowait(audio_bytes)
elif self._run_pipeline_task is None:
# Announcement only
if self._check_announcement_ended_task is None:
# Check if caller hung up
self._check_announcement_ended_task = (
self.config_entry.async_create_background_task(
self.hass,
self._check_announcement_ended(),
"voip_announcement_ended",
)
)

# Play announcement (will repeat)
self._run_pipeline_task = self.config_entry.async_create_background_task(
self.hass,
self._run_pipeline(),
"voip_pipeline_run",
self._play_announcement(self._announcement),
"voip_play_announcement",
)

self._audio_queue.put_nowait(audio_bytes)

async def _run_pipeline(self) -> None:
"""Run a pipeline with STT input and TTS output."""
_LOGGER.debug("Starting pipeline")

self.async_set_context(Context(user_id=self.config_entry.data["user"]))
Expand Down Expand Up @@ -209,6 +308,23 @@
self._run_pipeline_task = None
_LOGGER.debug("Pipeline finished")

async def _play_announcement(
self, announcement: AssistSatelliteAnnouncement
) -> None:
"""Play an announcement once."""
_LOGGER.debug("Playing announcement")

try:
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
except Exception:
_LOGGER.exception("Unexpected error while playing announcement")
raise

Check warning on line 323 in homeassistant/components/voip/assist_satellite.py

View check run for this annotation

Codecov / codecov/patch

homeassistant/components/voip/assist_satellite.py#L323

Added line #L323 was not covered by tests
finally:
self._run_pipeline_task = None
_LOGGER.debug("Announcement finished")

def _clear_audio_queue(self) -> None:
"""Ensure audio queue is empty."""
while not self._audio_queue.empty():
Expand Down Expand Up @@ -239,7 +355,7 @@
self._pipeline_had_error = True
_LOGGER.warning(event)

async def _send_tts(self, media_id: str) -> None:
async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None:
"""Send TTS audio to caller via RTP."""
try:
if self.transport is None:
Expand All @@ -253,7 +369,7 @@
if extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}")

if (self._tones & Tones.PROCESSING) == Tones.PROCESSING:
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
# Don't overlap TTS and processing beep
_LOGGER.debug("Waiting for processing tone")
await self._processing_tone_done.wait()
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/voip/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"name": "Voice over IP",
"codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true,
"dependencies": ["assist_pipeline", "assist_satellite"],
"dependencies": ["assist_pipeline", "assist_satellite", "network"],
"documentation": "https://www.home-assistant.io/integrations/voip",
"iot_class": "local_push",
"quality_scale": "internal",
Expand Down
99 changes: 98 additions & 1 deletion tests/components/voip/test_voip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# pylint: disable-next=hass-component-root-import
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.voip import HassVoipDatagramProtocol
from homeassistant.components.voip import DOMAIN, HassVoipDatagramProtocol
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
Expand Down Expand Up @@ -844,3 +844,100 @@ async def async_send_audio(audio_bytes: bytes, **kwargs):

assert sum(played_audio_bytes) > 0
assert played_audio_bytes == snapshot()


@pytest.mark.usefixtures("socket_enabled")
async def test_announce(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test announcement."""
assert await async_setup_component(hass, "voip", {})

satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
assert (
satellite.supported_features
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
)

announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
original_media_id=_MEDIA_ID,
media_id_source="tts",
)

# Protocol has already been mocked, but "outgoing_call" is not async
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
mock_protocol.outgoing_call = Mock()

with (
patch(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts,
):
satellite.transport = Mock()
announce_task = hass.async_create_background_task(
satellite.async_announce(announcement), "voip_announce"
)
await asyncio.sleep(0)
mock_protocol.outgoing_call.assert_called_once()

# Trigger announcement
satellite.on_chunk(bytes(_ONE_SECOND))
await announce_task

mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)


@pytest.mark.usefixtures("socket_enabled")
async def test_voip_id_is_ip_address(
hass: HomeAssistant,
voip_devices: VoIPDevices,
voip_device: VoIPDevice,
) -> None:
"""Test announcement when VoIP is an IP address instead of a SIP header."""
assert await async_setup_component(hass, "voip", {})

satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
assert isinstance(satellite, VoipAssistSatellite)
assert (
satellite.supported_features
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
)

announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
original_media_id=_MEDIA_ID,
media_id_source="tts",
)

# Protocol has already been mocked, but "outgoing_call" is not async
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
mock_protocol.outgoing_call = Mock()

with (
patch.object(voip_device, "voip_id", "192.168.68.10"),
patch(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts,
):
satellite.transport = Mock()
announce_task = hass.async_create_background_task(
satellite.async_announce(announcement), "voip_announce"
)
await asyncio.sleep(0)
mock_protocol.outgoing_call.assert_called_once()
assert (
mock_protocol.outgoing_call.call_args.kwargs["destination"].host
== "192.168.68.10"
)

# Trigger announcement
satellite.on_chunk(bytes(_ONE_SECOND))
await announce_task

mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)