Add assist satellite entity component (#125351)

* Add assist_satellite

* Update homeassistant/components/assist_satellite/manifest.json

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Update homeassistant/components/assist_satellite/manifest.json

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Add platform constant

* Update Dockerfile

* Apply suggestions from code review

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Address comments

* Update docstring async_internal_announce

* Update CODEOWNERS

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Michael Hansen
2024-09-05 20:16:30 -05:00
committed by GitHub
parent c3921f2112
commit 60b0f0dc53
22 changed files with 1188 additions and 4 deletions

View File

@ -14,6 +14,7 @@ core: &core
base_platforms: &base_platforms
- homeassistant/components/air_quality/**
- homeassistant/components/alarm_control_panel/**
- homeassistant/components/assist_satellite/**
- homeassistant/components/binary_sensor/**
- homeassistant/components/button/**
- homeassistant/components/calendar/**

View File

@ -95,6 +95,7 @@ homeassistant.components.aruba.*
homeassistant.components.arwn.*
homeassistant.components.aseko_pool_live.*
homeassistant.components.assist_pipeline.*
homeassistant.components.assist_satellite.*
homeassistant.components.asuswrt.*
homeassistant.components.autarco.*
homeassistant.components.auth.*

View File

@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
/tests/components/aseko_pool_live/ @milanmeu
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
/tests/components/assist_pipeline/ @balloob @synesthesiam
/homeassistant/components/assist_satellite/ @home-assistant/core @synesthesiam
/tests/components/assist_satellite/ @home-assistant/core @synesthesiam
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
/tests/components/asuswrt/ @kennedyshead @ollo69
/homeassistant/components/atag/ @MatsNL

View File

@ -17,6 +17,7 @@ from .const import (
DATA_LAST_WAKE_UP,
DOMAIN,
EVENT_RECORDING,
OPTION_PREFERRED,
SAMPLE_CHANNELS,
SAMPLE_RATE,
SAMPLE_WIDTH,
@ -58,6 +59,7 @@ __all__ = (
"PipelineNotFound",
"WakeWordSettings",
"EVENT_RECORDING",
"OPTION_PREFERRED",
"SAMPLES_PER_CHUNK",
"SAMPLE_RATE",
"SAMPLE_WIDTH",

View File

@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
MS_PER_CHUNK = 10
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
OPTION_PREFERRED = "preferred"

View File

@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import collection, entity_registry as er, restore_state
from .const import DOMAIN
from .const import DOMAIN, OPTION_PREFERRED
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
from .vad import VadSensitivity
OPTION_PREFERRED = "preferred"
@callback
def get_chosen_pipeline(

View File

@ -0,0 +1,65 @@
"""Base class for assist satellite entities."""
import logging
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN, AssistSatelliteEntityFeature
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
from .errors import SatelliteBusyError
from .websocket_api import async_register_websocket_api
__all__ = [
"DOMAIN",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
"AssistSatelliteEntityFeature",
"SatelliteBusyError",
]
_LOGGER = logging.getLogger(__name__)
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
_LOGGER, DOMAIN, hass
)
await component.async_setup(config)
component.async_register_entity_service(
"announce",
vol.All(
cv.make_entity_service_schema(
{
vol.Optional("message"): str,
vol.Optional("media_id"): str,
}
),
cv.has_at_least_one_key("message", "media_id"),
),
"async_internal_announce",
[AssistSatelliteEntityFeature.ANNOUNCE],
)
async_register_websocket_api(hass)
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry."""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
return await component.async_setup_entry(entry)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
return await component.async_unload_entry(entry)

View File

@ -0,0 +1,12 @@
"""Constants for assist satellite."""
from enum import IntFlag
DOMAIN = "assist_satellite"
class AssistSatelliteEntityFeature(IntFlag):
"""Supported features of Assist satellite entity."""
ANNOUNCE = 1
"""Device supports remotely triggered announcements."""

View File

@ -0,0 +1,332 @@
"""Assist satellite entity."""
from abc import abstractmethod
import asyncio
from collections.abc import AsyncIterable
from enum import StrEnum
import logging
import time
from typing import Any, Final, final
from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipeline,
async_get_pipelines,
async_pipeline_from_audio_stream,
vad,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.components.tts.media_source import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, callback
from homeassistant.helpers import entity
from homeassistant.helpers.entity import EntityDescription
from homeassistant.util import ulid
from .const import AssistSatelliteEntityFeature
from .errors import AssistSatelliteError, SatelliteBusyError
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
_LOGGER = logging.getLogger(__name__)
class AssistSatelliteState(StrEnum):
"""Valid states of an Assist satellite entity."""
LISTENING_WAKE_WORD = "listening_wake_word"
"""Device is streaming audio for wake word detection to Home Assistant."""
LISTENING_COMMAND = "listening_command"
"""Device is streaming audio with the voice command to Home Assistant."""
PROCESSING = "processing"
"""Home Assistant is processing the voice command."""
RESPONDING = "responding"
"""Device is speaking the response."""
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
"""A class that describes Assist satellite entities."""
class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""
entity_description: AssistSatelliteEntityDescription
_attr_should_poll = False
_attr_supported_features = AssistSatelliteEntityFeature(0)
_attr_pipeline_entity_id: str | None = None
_attr_vad_sensitivity_entity_id: str | None = None
_conversation_id: str | None = None
_conversation_id_time: float | None = None
_run_has_tts: bool = False
_is_announcing = False
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
__assist_satellite_state: AssistSatelliteState | None = None
@final
@property
def state(self) -> str | None:
"""Return state of the entity."""
return self.__assist_satellite_state
@property
def pipeline_entity_id(self) -> str | None:
"""Entity ID of the pipeline to use for the next conversation."""
return self._attr_pipeline_entity_id
@property
def vad_sensitivity_entity_id(self) -> str | None:
"""Entity ID of the VAD sensitivity to use for the next conversation."""
return self._attr_vad_sensitivity_entity_id
async def async_intercept_wake_word(self) -> str | None:
"""Intercept the next wake word from the satellite.
Returns the detected wake word phrase or None.
"""
if self._wake_word_intercept_future is not None:
raise SatelliteBusyError("Wake word interception already in progress")
# Will cause next wake word to be intercepted in
# async_accept_pipeline_from_satellite
self._wake_word_intercept_future = asyncio.Future()
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)
try:
return await self._wake_word_intercept_future
finally:
self._wake_word_intercept_future = None
async def async_internal_announce(
self,
message: str | None = None,
media_id: str | None = None,
) -> None:
"""Play and show an announcement on the satellite.
If media_id is not provided, message is synthesized to
audio with the selected pipeline.
If media_id is provided, it is played directly. It is possible
to omit the message and the satellite will not show any text.
Calls async_announce with message and media id.
"""
if message is None:
message = ""
if not media_id:
# Synthesize audio and get URL
pipeline_id = self._resolve_pipeline()
pipeline = async_get_pipeline(self.hass, pipeline_id)
tts_options: dict[str, Any] = {}
if pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
media_id = tts_generate_media_source_id(
self.hass,
message,
engine=pipeline.tts_engine,
language=pipeline.tts_language,
options=tts_options,
)
if media_source.is_media_source_id(media_id):
media = await media_source.async_resolve_media(
self.hass,
media_id,
None,
)
media_id = media.url
# Resolve to full URL
media_id = async_process_play_media_url(self.hass, media_id)
if self._is_announcing:
raise SatelliteBusyError
self._is_announcing = True
try:
# Block until announcement is finished
await self.async_announce(message, media_id)
finally:
self._is_announcing = False
async def async_announce(self, message: str, media_id: str) -> None:
"""Announce media on the satellite.
Should block until the announcement is done playing.
"""
raise NotImplementedError
async def async_accept_pipeline_from_satellite(
self,
audio_stream: AsyncIterable[bytes],
start_stage: PipelineStage = PipelineStage.STT,
end_stage: PipelineStage = PipelineStage.TTS,
wake_word_phrase: str | None = None,
) -> None:
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
if self._wake_word_intercept_future and start_stage in (
PipelineStage.WAKE_WORD,
PipelineStage.STT,
):
if start_stage == PipelineStage.WAKE_WORD:
self._wake_word_intercept_future.set_exception(
AssistSatelliteError(
"Only on-device wake words currently supported"
)
)
return
# Intercepting wake word and immediately end pipeline
_LOGGER.debug(
"Intercepted wake word: %s (entity_id=%s)",
wake_word_phrase,
self.entity_id,
)
if wake_word_phrase is None:
self._wake_word_intercept_future.set_exception(
AssistSatelliteError("No wake word phrase provided")
)
else:
self._wake_word_intercept_future.set_result(wake_word_phrase)
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
return
device_id = self.registry_entry.device_id if self.registry_entry else None
# Refresh context if necessary
if (
(self._context is None)
or (self._context_set is None)
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
):
self.async_set_context(Context())
assert self._context is not None
# Reset conversation id if necessary
if (self._conversation_id_time is None) or (
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
):
self._conversation_id = None
if self._conversation_id is None:
self._conversation_id = ulid.ulid()
# Update timeout
self._conversation_id_time = time.monotonic()
# Set entity state based on pipeline events
self._run_has_tts = False
await async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output="wav",
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
)
@abstractmethod
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
@callback
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
if event.type is PipelineEventType.WAKE_WORD_START:
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
elif event.type is PipelineEventType.STT_START:
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type is PipelineEventType.TTS_START:
# Wait until tts_response_finished is called to return to waiting state
self._run_has_tts = True
self._set_state(AssistSatelliteState.RESPONDING)
elif event.type is PipelineEventType.RUN_END:
if not self._run_has_tts:
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
self.on_pipeline_event(event)
@callback
def _set_state(self, state: AssistSatelliteState) -> None:
"""Set the entity's state."""
self.__assist_satellite_state = state
self.async_write_ha_state()
@callback
def tts_response_finished(self) -> None:
"""Tell entity that the text-to-speech response has finished playing."""
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
@callback
def _resolve_pipeline(self) -> str | None:
"""Resolve pipeline from select entity to id.
Return None to make async_get_pipeline look up the preferred pipeline.
"""
if not (pipeline_entity_id := self.pipeline_entity_id):
return None
if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
raise RuntimeError("Pipeline entity not found")
if pipeline_entity_state.state != OPTION_PREFERRED:
# Resolve pipeline by name
for pipeline in async_get_pipelines(self.hass):
if pipeline.name == pipeline_entity_state.state:
return pipeline.id
return None
@callback
def _resolve_vad_sensitivity(self) -> float:
"""Resolve VAD sensitivity from select entity to enum."""
vad_sensitivity = vad.VadSensitivity.DEFAULT
if vad_sensitivity_entity_id := self.vad_sensitivity_entity_id:
if (
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
) is None:
raise RuntimeError("VAD sensitivity entity not found")
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
return vad.VadSensitivity.to_seconds(vad_sensitivity)

View File

@ -0,0 +1,11 @@
"""Errors for assist satellite."""
from homeassistant.exceptions import HomeAssistantError
class AssistSatelliteError(HomeAssistantError):
"""Base class for assist satellite errors."""
class SatelliteBusyError(AssistSatelliteError):
"""Satellite is busy and cannot handle the request."""

View File

@ -0,0 +1,12 @@
{
"entity_component": {
"_": {
"default": "mdi:account-voice"
}
},
"services": {
"announce": {
"service": "mdi:bullhorn"
}
}
}

View File

@ -0,0 +1,9 @@
{
"domain": "assist_satellite",
"name": "Assist Satellite",
"codeowners": ["@home-assistant/core", "@synesthesiam"],
"dependencies": ["assist_pipeline", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity",
"quality_scale": "internal"
}

View File

@ -0,0 +1,16 @@
announce:
target:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
fields:
message:
required: false
example: "Time to wake up!"
selector:
text:
media_id:
required: false
selector:
text:

View File

@ -0,0 +1,30 @@
{
"title": "Assist satellite",
"entity_component": {
"_": {
"name": "Assist satellite",
"state": {
"listening_wake_word": "Wake word",
"listening_command": "Voice command",
"responding": "Responding",
"processing": "Processing"
}
}
},
"services": {
"announce": {
"name": "Announce",
"description": "Let the satellite announce a message.",
"fields": {
"message": {
"name": "Message",
"description": "The message to announce."
},
"media_id": {
"name": "Media ID",
"description": "The media ID to announce instead of using text-to-speech."
}
}
}
}
}

View File

@ -0,0 +1,46 @@
"""Assist satellite Websocket API."""
from typing import Any
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from .const import DOMAIN
from .entity import AssistSatelliteEntity
@callback
def async_register_websocket_api(hass: HomeAssistant) -> None:
"""Register the websocket API."""
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
@callback
@websocket_api.websocket_command(
{
vol.Required("type"): "assist_satellite/intercept_wake_word",
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
}
)
@websocket_api.require_admin
@websocket_api.async_response
async def websocket_intercept_wake_word(
hass: HomeAssistant,
connection: websocket_api.connection.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Intercept the next wake word from a satellite."""
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
satellite = component.get_entity(msg["entity_id"])
if satellite is None:
connection.send_error(
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
)
return
wake_word_phrase = await satellite.async_intercept_wake_word()
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})

View File

@ -41,6 +41,7 @@ class Platform(StrEnum):
AIR_QUALITY = "air_quality"
ALARM_CONTROL_PANEL = "alarm_control_panel"
ASSIST_SATELLITE = "assist_satellite"
BINARY_SENSOR = "binary_sensor"
BUTTON = "button"
CALENDAR = "calendar"

View File

@ -705,6 +705,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.assist_satellite.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.asuswrt.*]
check_untyped_defs = true
disallow_incomplete_defs = true

View File

@ -23,7 +23,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.2.27,source=/uv,target=/bin/uv \
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \
-r /usr/src/homeassistant/requirements.txt \
stdlib-list==0.10.0 pipdeptree==2.23.1 tqdm==4.66.4 ruff==0.6.2 \
PyTurboJPEG==1.7.5 ha-ffmpeg==3.2.0 hassil==1.7.4 home-assistant-intents==2024.9.4 mutagen==1.47.0
PyTurboJPEG==1.7.5 ha-ffmpeg==3.2.0 hassil==1.7.4 home-assistant-intents==2024.9.4 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
LABEL "name"="hassfest"
LABEL "maintainer"="Home Assistant <hello@home-assistant.io>"

View File

@ -0,0 +1,3 @@
"""Tests for Assist Satellite."""
ENTITY_ID = "assist_satellite.test_entity"

View File

@ -0,0 +1,107 @@
"""Test helpers for Assist Satellite."""
import pathlib
from unittest.mock import Mock
import pytest
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN,
AssistSatelliteEntity,
AssistSatelliteEntityFeature,
)
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
mock_config_flow,
mock_integration,
mock_platform,
setup_test_component_platform,
)
TEST_DOMAIN = "test"
@pytest.fixture(autouse=True)
def mock_tts(mock_tts_cache_dir: pathlib.Path) -> None:
"""Mock TTS cache dir fixture."""
class MockAssistSatellite(AssistSatelliteEntity):
"""Mock Assist Satellite Entity."""
_attr_name = "Test Entity"
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
def __init__(self) -> None:
"""Initialize the mock entity."""
self.events = []
self.announcements = []
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
self.events.append(event)
async def async_announce(self, message: str, media_id: str) -> None:
"""Announce media on a device."""
self.announcements.append((message, media_id))
@pytest.fixture
def entity() -> MockAssistSatellite:
"""Mock Assist Satellite Entity."""
return MockAssistSatellite()
@pytest.fixture
def config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Mock config entry."""
entry = MockConfigEntry(domain=TEST_DOMAIN)
entry.add_to_hass(hass)
return entry
@pytest.fixture
async def init_components(
hass: HomeAssistant,
config_entry: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Initialize components."""
assert await async_setup_component(hass, "homeassistant", {})
async def async_setup_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Set up test config entry."""
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
return True
async def async_unload_entry_init(
hass: HomeAssistant, config_entry: ConfigEntry
) -> bool:
"""Unload test config entry."""
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
return True
mock_integration(
hass,
MockModule(
TEST_DOMAIN,
async_setup_entry=async_setup_entry_init,
async_unload_entry=async_unload_entry_init,
),
)
setup_test_component_platform(hass, AS_DOMAIN, [entity], from_config_entry=True)
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
return config_entry

View File

@ -0,0 +1,332 @@
"""Test the Assist Satellite entity."""
import asyncio
from unittest.mock import patch
import pytest
from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
AudioSettings,
Pipeline,
PipelineEvent,
PipelineEventType,
PipelineStage,
async_get_pipeline,
async_update_pipeline,
vad,
)
from homeassistant.components.assist_satellite import SatelliteBusyError
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import Context, HomeAssistant
from . import ENTITY_ID
from .conftest import MockAssistSatellite
async def test_entity_state(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Test entity state represent events."""
state = hass.states.get(ENTITY_ID)
assert state is not None
assert state.state == STATE_UNKNOWN
context = Context()
audio_stream = object()
entity.async_set_context(context)
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
) as mock_start_pipeline:
await entity.async_accept_pipeline_from_satellite(audio_stream)
assert mock_start_pipeline.called
kwargs = mock_start_pipeline.call_args[1]
assert kwargs["context"] is context
assert kwargs["event_callback"] == entity._internal_on_pipeline_event
assert kwargs["stt_metadata"] == stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
)
assert kwargs["stt_stream"] is audio_stream
assert kwargs["pipeline_id"] is None
assert kwargs["device_id"] is None
assert kwargs["tts_audio_output"] == "wav"
assert kwargs["wake_word_phrase"] is None
assert kwargs["audio_settings"] == AudioSettings(
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
)
assert kwargs["start_stage"] == PipelineStage.STT
assert kwargs["end_stage"] == PipelineStage.TTS
for event_type, expected_state in (
(PipelineEventType.RUN_START, STATE_UNKNOWN),
(PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
):
kwargs["event_callback"](PipelineEvent(event_type, {}))
state = hass.states.get(ENTITY_ID)
assert state.state == expected_state, event_type
entity.tts_response_finished()
state = hass.states.get(ENTITY_ID)
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
@pytest.mark.parametrize(
("service_data", "expected_params"),
[
(
{"message": "Hello"},
("Hello", "https://www.home-assistant.io/resolved.mp3"),
),
(
{
"message": "Hello",
"media_id": "http://example.com/bla.mp3",
},
("Hello", "http://example.com/bla.mp3"),
),
(
{"media_id": "http://example.com/bla.mp3"},
("", "http://example.com/bla.mp3"),
),
],
)
async def test_announce(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
service_data: dict,
expected_params: tuple[str, str],
) -> None:
"""Test announcing on a device."""
await async_update_pipeline(
hass,
async_get_pipeline(hass),
tts_engine="tts.mock_entity",
tts_language="en",
tts_voice="test-voice",
)
with (
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=PlayMedia(
url="https://www.home-assistant.io/resolved.mp3",
mime_type="audio/mp3",
),
),
):
await hass.services.async_call(
"assist_satellite",
"announce",
service_data,
target={"entity_id": "assist_satellite.test_entity"},
blocking=True,
)
assert entity.announcements[0] == expected_params
async def test_announce_busy(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Test that announcing while an announcement is in progress raises an error."""
media_id = "https://www.home-assistant.io/resolved.mp3"
announce_started = asyncio.Event()
got_error = asyncio.Event()
async def async_announce(message, media_id):
announce_started.set()
# Block so we can do another announcement
await got_error.wait()
with patch.object(entity, "async_announce", new=async_announce):
announce_task = asyncio.create_task(
entity.async_internal_announce(media_id=media_id)
)
async with asyncio.timeout(1):
await announce_started.wait()
# Try to do a second announcement
with pytest.raises(SatelliteBusyError):
await entity.async_internal_announce(media_id=media_id)
# Avoid lingering task
got_error.set()
await announce_task
async def test_context_refresh(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Test that the context will be automatically refreshed."""
audio_stream = object()
# Remove context
entity._context = None
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
):
await entity.async_accept_pipeline_from_satellite(audio_stream)
# Context should have been refreshed
assert entity._context is not None
async def test_pipeline_entity(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Test getting pipeline from an entity."""
audio_stream = object()
pipeline = Pipeline(
conversation_engine="test",
conversation_language="en",
language="en",
name="test-pipeline",
stt_engine=None,
stt_language=None,
tts_engine=None,
tts_language=None,
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
pipeline_entity_id = "select.pipeline"
hass.states.async_set(pipeline_entity_id, pipeline.name)
entity._attr_pipeline_entity_id = pipeline_entity_id
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, pipeline_id: str, **kwargs):
assert pipeline_id == pipeline.id
done.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.assist_satellite.entity.async_get_pipelines",
return_value=[pipeline],
),
):
async with asyncio.timeout(1):
await entity.async_accept_pipeline_from_satellite(audio_stream)
await done.wait()
async def test_pipeline_entity_preferred(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Test getting pipeline from an entity with a preferred state."""
audio_stream = object()
pipeline_entity_id = "select.pipeline"
hass.states.async_set(pipeline_entity_id, OPTION_PREFERRED)
entity._attr_pipeline_entity_id = pipeline_entity_id
done = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, pipeline_id: str, **kwargs):
# Preferred pipeline
assert pipeline_id is None
done.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
):
async with asyncio.timeout(1):
await entity.async_accept_pipeline_from_satellite(audio_stream)
await done.wait()
async def test_vad_sensitivity_entity(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Test getting vad sensitivity from an entity."""
audio_stream = object()
vad_sensitivity_entity_id = "select.vad_sensitivity"
hass.states.async_set(vad_sensitivity_entity_id, vad.VadSensitivity.AGGRESSIVE)
entity._attr_vad_sensitivity_entity_id = vad_sensitivity_entity_id
done = asyncio.Event()
async def async_pipeline_from_audio_stream(
*args, audio_settings: AudioSettings, **kwargs
):
# Verify vad sensitivity
assert audio_settings.silence_seconds == vad.VadSensitivity.to_seconds(
vad.VadSensitivity.AGGRESSIVE
)
done.set()
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
async with asyncio.timeout(1):
await entity.async_accept_pipeline_from_satellite(audio_stream)
await done.wait()
async def test_pipeline_entity_not_found(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Test that setting the pipeline entity id to a non-existent entity raises an error."""
audio_stream = object()
# Set to an entity that doesn't exist
entity._attr_pipeline_entity_id = "select.pipeline"
with pytest.raises(RuntimeError):
await entity.async_accept_pipeline_from_satellite(audio_stream)
async def test_vad_sensitivity_entity_not_found(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None:
"""Test that setting the vad sensitivity entity id to a non-existent entity raises an error."""
audio_stream = object()
# Set to an entity that doesn't exist
entity._attr_vad_sensitivity_entity_id = "select.vad_sensitivity"
with pytest.raises(RuntimeError):
await entity.async_accept_pipeline_from_satellite(audio_stream)

View File

@ -0,0 +1,192 @@
"""Test WebSocket API."""
import asyncio
from homeassistant.components.assist_pipeline import PipelineStage
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from . import ENTITY_ID
from .conftest import MockAssistSatellite
from tests.common import MockUser
from tests.typing import WebSocketGenerator
async def test_intercept_wake_word(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test intercepting a wake word."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
for _ in range(3):
await asyncio.sleep(0)
await entity.async_accept_pipeline_from_satellite(
object(),
start_stage=PipelineStage.STT,
wake_word_phrase="ok, nabu",
)
response = await ws_client.receive_json()
assert response["success"]
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
async def test_intercept_wake_word_requires_on_device_wake_word(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test intercepting a wake word fails if detection happens in HA."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
for _ in range(3):
await asyncio.sleep(0)
await entity.async_accept_pipeline_from_satellite(
object(),
# Emulate wake word processing in Home Assistant
start_stage=PipelineStage.WAKE_WORD,
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "home_assistant_error",
"message": "Only on-device wake words currently supported",
}
async def test_intercept_wake_word_requires_wake_word_phrase(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test intercepting a wake word fails if detection happens in HA."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
for _ in range(3):
await asyncio.sleep(0)
await entity.async_accept_pipeline_from_satellite(
object(),
start_stage=PipelineStage.STT,
# We are not passing wake word phrase
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "home_assistant_error",
"message": "No wake word phrase provided",
}
async def test_intercept_wake_word_require_admin(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
hass_admin_user: MockUser,
) -> None:
"""Test intercepting a wake word requires admin access."""
# Remove admin permission and verify we're not allowed
hass_admin_user.groups = []
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "unauthorized",
"message": "Unauthorized",
}
async def test_intercept_wake_word_invalid_satellite(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test intercepting a wake word requires admin access."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": "assist_satellite.invalid",
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "not_found",
"message": "Entity not found",
}
async def test_intercept_wake_word_twice(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
hass_ws_client: WebSocketGenerator,
) -> None:
"""Test intercepting a wake word requires admin access."""
ws_client = await hass_ws_client(hass)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
await ws_client.send_json_auto_id(
{
"type": "assist_satellite/intercept_wake_word",
"entity_id": ENTITY_ID,
}
)
response = await ws_client.receive_json()
assert not response["success"]
assert response["error"] == {
"code": "home_assistant_error",
"message": "Wake word interception already in progress",
}