mirror of
https://github.com/home-assistant/core.git
synced 2025-06-25 01:21:51 +02:00
Add assist satellite entity component (#125351)
* Add assist_satellite * Update homeassistant/components/assist_satellite/manifest.json Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Update homeassistant/components/assist_satellite/manifest.json Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Add platform constant * Update Dockerfile * Apply suggestions from code review Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Address comments * Update docstring async_internal_announce * Update CODEOWNERS --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com> Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
@ -14,6 +14,7 @@ core: &core
|
||||
base_platforms: &base_platforms
|
||||
- homeassistant/components/air_quality/**
|
||||
- homeassistant/components/alarm_control_panel/**
|
||||
- homeassistant/components/assist_satellite/**
|
||||
- homeassistant/components/binary_sensor/**
|
||||
- homeassistant/components/button/**
|
||||
- homeassistant/components/calendar/**
|
||||
|
@ -95,6 +95,7 @@ homeassistant.components.aruba.*
|
||||
homeassistant.components.arwn.*
|
||||
homeassistant.components.aseko_pool_live.*
|
||||
homeassistant.components.assist_pipeline.*
|
||||
homeassistant.components.assist_satellite.*
|
||||
homeassistant.components.asuswrt.*
|
||||
homeassistant.components.autarco.*
|
||||
homeassistant.components.auth.*
|
||||
|
@ -143,6 +143,8 @@ build.json @home-assistant/supervisor
|
||||
/tests/components/aseko_pool_live/ @milanmeu
|
||||
/homeassistant/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/tests/components/assist_pipeline/ @balloob @synesthesiam
|
||||
/homeassistant/components/assist_satellite/ @home-assistant/core @synesthesiam
|
||||
/tests/components/assist_satellite/ @home-assistant/core @synesthesiam
|
||||
/homeassistant/components/asuswrt/ @kennedyshead @ollo69
|
||||
/tests/components/asuswrt/ @kennedyshead @ollo69
|
||||
/homeassistant/components/atag/ @MatsNL
|
||||
|
@ -17,6 +17,7 @@ from .const import (
|
||||
DATA_LAST_WAKE_UP,
|
||||
DOMAIN,
|
||||
EVENT_RECORDING,
|
||||
OPTION_PREFERRED,
|
||||
SAMPLE_CHANNELS,
|
||||
SAMPLE_RATE,
|
||||
SAMPLE_WIDTH,
|
||||
@ -58,6 +59,7 @@ __all__ = (
|
||||
"PipelineNotFound",
|
||||
"WakeWordSettings",
|
||||
"EVENT_RECORDING",
|
||||
"OPTION_PREFERRED",
|
||||
"SAMPLES_PER_CHUNK",
|
||||
"SAMPLE_RATE",
|
||||
"SAMPLE_WIDTH",
|
||||
|
@ -22,3 +22,5 @@ SAMPLE_CHANNELS = 1 # mono
|
||||
MS_PER_CHUNK = 10
|
||||
SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
|
||||
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
@ -9,12 +9,10 @@ from homeassistant.const import EntityCategory, Platform
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import collection, entity_registry as er, restore_state
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, OPTION_PREFERRED
|
||||
from .pipeline import AssistDevice, PipelineData, PipelineStorageCollection
|
||||
from .vad import VadSensitivity
|
||||
|
||||
OPTION_PREFERRED = "preferred"
|
||||
|
||||
|
||||
@callback
|
||||
def get_chosen_pipeline(
|
||||
|
65
homeassistant/components/assist_satellite/__init__.py
Normal file
65
homeassistant/components/assist_satellite/__init__.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Base class for assist satellite entities."""
|
||||
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN, AssistSatelliteEntityFeature
|
||||
from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
|
||||
from .errors import SatelliteBusyError
|
||||
from .websocket_api import async_register_websocket_api
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteEntity",
|
||||
"AssistSatelliteEntityDescription",
|
||||
"AssistSatelliteEntityFeature",
|
||||
"SatelliteBusyError",
|
||||
]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PLATFORM_SCHEMA_BASE = cv.PLATFORM_SCHEMA_BASE
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
component = hass.data[DOMAIN] = EntityComponent[AssistSatelliteEntity](
|
||||
_LOGGER, DOMAIN, hass
|
||||
)
|
||||
await component.async_setup(config)
|
||||
|
||||
component.async_register_entity_service(
|
||||
"announce",
|
||||
vol.All(
|
||||
cv.make_entity_service_schema(
|
||||
{
|
||||
vol.Optional("message"): str,
|
||||
vol.Optional("media_id"): str,
|
||||
}
|
||||
),
|
||||
cv.has_at_least_one_key("message", "media_id"),
|
||||
),
|
||||
"async_internal_announce",
|
||||
[AssistSatelliteEntityFeature.ANNOUNCE],
|
||||
)
|
||||
async_register_websocket_api(hass)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_setup_entry(entry)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
return await component.async_unload_entry(entry)
|
12
homeassistant/components/assist_satellite/const.py
Normal file
12
homeassistant/components/assist_satellite/const.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""Constants for assist satellite."""
|
||||
|
||||
from enum import IntFlag
|
||||
|
||||
DOMAIN = "assist_satellite"
|
||||
|
||||
|
||||
class AssistSatelliteEntityFeature(IntFlag):
|
||||
"""Supported features of Assist satellite entity."""
|
||||
|
||||
ANNOUNCE = 1
|
||||
"""Device supports remotely triggered announcements."""
|
332
homeassistant/components/assist_satellite/entity.py
Normal file
332
homeassistant/components/assist_satellite/entity.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""Assist satellite entity."""
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from enum import StrEnum
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Final, final
|
||||
|
||||
from homeassistant.components import media_source, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
OPTION_PREFERRED,
|
||||
AudioSettings,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
async_pipeline_from_audio_stream,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.components.tts.media_source import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers.entity import EntityDescription
|
||||
from homeassistant.util import ulid
|
||||
|
||||
from .const import AssistSatelliteEntityFeature
|
||||
from .errors import AssistSatelliteError, SatelliteBusyError
|
||||
|
||||
_CONVERSATION_TIMEOUT_SEC: Final = 5 * 60 # 5 minutes
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssistSatelliteState(StrEnum):
|
||||
"""Valid states of an Assist satellite entity."""
|
||||
|
||||
LISTENING_WAKE_WORD = "listening_wake_word"
|
||||
"""Device is streaming audio for wake word detection to Home Assistant."""
|
||||
|
||||
LISTENING_COMMAND = "listening_command"
|
||||
"""Device is streaming audio with the voice command to Home Assistant."""
|
||||
|
||||
PROCESSING = "processing"
|
||||
"""Home Assistant is processing the voice command."""
|
||||
|
||||
RESPONDING = "responding"
|
||||
"""Device is speaking the response."""
|
||||
|
||||
|
||||
class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True):
|
||||
"""A class that describes Assist satellite entities."""
|
||||
|
||||
|
||||
class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||
|
||||
entity_description: AssistSatelliteEntityDescription
|
||||
_attr_should_poll = False
|
||||
_attr_supported_features = AssistSatelliteEntityFeature(0)
|
||||
_attr_pipeline_entity_id: str | None = None
|
||||
_attr_vad_sensitivity_entity_id: str | None = None
|
||||
|
||||
_conversation_id: str | None = None
|
||||
_conversation_id_time: float | None = None
|
||||
|
||||
_run_has_tts: bool = False
|
||||
_is_announcing = False
|
||||
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||
|
||||
__assist_satellite_state: AssistSatelliteState | None = None
|
||||
|
||||
@final
|
||||
@property
|
||||
def state(self) -> str | None:
|
||||
"""Return state of the entity."""
|
||||
return self.__assist_satellite_state
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Entity ID of the pipeline to use for the next conversation."""
|
||||
return self._attr_pipeline_entity_id
|
||||
|
||||
@property
|
||||
def vad_sensitivity_entity_id(self) -> str | None:
|
||||
"""Entity ID of the VAD sensitivity to use for the next conversation."""
|
||||
return self._attr_vad_sensitivity_entity_id
|
||||
|
||||
async def async_intercept_wake_word(self) -> str | None:
|
||||
"""Intercept the next wake word from the satellite.
|
||||
|
||||
Returns the detected wake word phrase or None.
|
||||
"""
|
||||
if self._wake_word_intercept_future is not None:
|
||||
raise SatelliteBusyError("Wake word interception already in progress")
|
||||
|
||||
# Will cause next wake word to be intercepted in
|
||||
# async_accept_pipeline_from_satellite
|
||||
self._wake_word_intercept_future = asyncio.Future()
|
||||
|
||||
_LOGGER.debug("Next wake word will be intercepted: %s", self.entity_id)
|
||||
|
||||
try:
|
||||
return await self._wake_word_intercept_future
|
||||
finally:
|
||||
self._wake_word_intercept_future = None
|
||||
|
||||
async def async_internal_announce(
|
||||
self,
|
||||
message: str | None = None,
|
||||
media_id: str | None = None,
|
||||
) -> None:
|
||||
"""Play and show an announcement on the satellite.
|
||||
|
||||
If media_id is not provided, message is synthesized to
|
||||
audio with the selected pipeline.
|
||||
|
||||
If media_id is provided, it is played directly. It is possible
|
||||
to omit the message and the satellite will not show any text.
|
||||
|
||||
Calls async_announce with message and media id.
|
||||
"""
|
||||
if message is None:
|
||||
message = ""
|
||||
|
||||
if not media_id:
|
||||
# Synthesize audio and get URL
|
||||
pipeline_id = self._resolve_pipeline()
|
||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||
|
||||
tts_options: dict[str, Any] = {}
|
||||
if pipeline.tts_voice is not None:
|
||||
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
|
||||
|
||||
media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
message,
|
||||
engine=pipeline.tts_engine,
|
||||
language=pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
|
||||
if media_source.is_media_source_id(media_id):
|
||||
media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
media_id,
|
||||
None,
|
||||
)
|
||||
media_id = media.url
|
||||
|
||||
# Resolve to full URL
|
||||
media_id = async_process_play_media_url(self.hass, media_id)
|
||||
|
||||
if self._is_announcing:
|
||||
raise SatelliteBusyError
|
||||
|
||||
self._is_announcing = True
|
||||
|
||||
try:
|
||||
# Block until announcement is finished
|
||||
await self.async_announce(message, media_id)
|
||||
finally:
|
||||
self._is_announcing = False
|
||||
|
||||
async def async_announce(self, message: str, media_id: str) -> None:
|
||||
"""Announce media on the satellite.
|
||||
|
||||
Should block until the announcement is done playing.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_accept_pipeline_from_satellite(
|
||||
self,
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
start_stage: PipelineStage = PipelineStage.STT,
|
||||
end_stage: PipelineStage = PipelineStage.TTS,
|
||||
wake_word_phrase: str | None = None,
|
||||
) -> None:
|
||||
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
|
||||
if self._wake_word_intercept_future and start_stage in (
|
||||
PipelineStage.WAKE_WORD,
|
||||
PipelineStage.STT,
|
||||
):
|
||||
if start_stage == PipelineStage.WAKE_WORD:
|
||||
self._wake_word_intercept_future.set_exception(
|
||||
AssistSatelliteError(
|
||||
"Only on-device wake words currently supported"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Intercepting wake word and immediately end pipeline
|
||||
_LOGGER.debug(
|
||||
"Intercepted wake word: %s (entity_id=%s)",
|
||||
wake_word_phrase,
|
||||
self.entity_id,
|
||||
)
|
||||
|
||||
if wake_word_phrase is None:
|
||||
self._wake_word_intercept_future.set_exception(
|
||||
AssistSatelliteError("No wake word phrase provided")
|
||||
)
|
||||
else:
|
||||
self._wake_word_intercept_future.set_result(wake_word_phrase)
|
||||
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
||||
return
|
||||
|
||||
device_id = self.registry_entry.device_id if self.registry_entry else None
|
||||
|
||||
# Refresh context if necessary
|
||||
if (
|
||||
(self._context is None)
|
||||
or (self._context_set is None)
|
||||
or ((time.time() - self._context_set) > entity.CONTEXT_RECENT_TIME_SECONDS)
|
||||
):
|
||||
self.async_set_context(Context())
|
||||
|
||||
assert self._context is not None
|
||||
|
||||
# Reset conversation id if necessary
|
||||
if (self._conversation_id_time is None) or (
|
||||
(time.monotonic() - self._conversation_id_time) > _CONVERSATION_TIMEOUT_SEC
|
||||
):
|
||||
self._conversation_id = None
|
||||
|
||||
if self._conversation_id is None:
|
||||
self._conversation_id = ulid.ulid()
|
||||
|
||||
# Update timeout
|
||||
self._conversation_id_time = time.monotonic()
|
||||
|
||||
# Set entity state based on pipeline events
|
||||
self._run_has_tts = False
|
||||
|
||||
await async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=self._context,
|
||||
event_callback=self._internal_on_pipeline_event,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="", # set in async_pipeline_from_audio_stream
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_stream,
|
||||
pipeline_id=self._resolve_pipeline(),
|
||||
conversation_id=self._conversation_id,
|
||||
device_id=device_id,
|
||||
tts_audio_output="wav",
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
audio_settings=AudioSettings(
|
||||
silence_seconds=self._resolve_vad_sensitivity()
|
||||
),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
|
||||
@callback
|
||||
def _internal_on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
if event.type is PipelineEventType.WAKE_WORD_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
elif event.type is PipelineEventType.STT_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING_COMMAND)
|
||||
elif event.type is PipelineEventType.INTENT_START:
|
||||
self._set_state(AssistSatelliteState.PROCESSING)
|
||||
elif event.type is PipelineEventType.TTS_START:
|
||||
# Wait until tts_response_finished is called to return to waiting state
|
||||
self._run_has_tts = True
|
||||
self._set_state(AssistSatelliteState.RESPONDING)
|
||||
elif event.type is PipelineEventType.RUN_END:
|
||||
if not self._run_has_tts:
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
self.on_pipeline_event(event)
|
||||
|
||||
@callback
|
||||
def _set_state(self, state: AssistSatelliteState) -> None:
|
||||
"""Set the entity's state."""
|
||||
self.__assist_satellite_state = state
|
||||
self.async_write_ha_state()
|
||||
|
||||
@callback
|
||||
def tts_response_finished(self) -> None:
|
||||
"""Tell entity that the text-to-speech response has finished playing."""
|
||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||
|
||||
@callback
|
||||
def _resolve_pipeline(self) -> str | None:
|
||||
"""Resolve pipeline from select entity to id.
|
||||
|
||||
Return None to make async_get_pipeline look up the preferred pipeline.
|
||||
"""
|
||||
if not (pipeline_entity_id := self.pipeline_entity_id):
|
||||
return None
|
||||
|
||||
if (pipeline_entity_state := self.hass.states.get(pipeline_entity_id)) is None:
|
||||
raise RuntimeError("Pipeline entity not found")
|
||||
|
||||
if pipeline_entity_state.state != OPTION_PREFERRED:
|
||||
# Resolve pipeline by name
|
||||
for pipeline in async_get_pipelines(self.hass):
|
||||
if pipeline.name == pipeline_entity_state.state:
|
||||
return pipeline.id
|
||||
|
||||
return None
|
||||
|
||||
@callback
|
||||
def _resolve_vad_sensitivity(self) -> float:
|
||||
"""Resolve VAD sensitivity from select entity to enum."""
|
||||
vad_sensitivity = vad.VadSensitivity.DEFAULT
|
||||
|
||||
if vad_sensitivity_entity_id := self.vad_sensitivity_entity_id:
|
||||
if (
|
||||
vad_sensitivity_state := self.hass.states.get(vad_sensitivity_entity_id)
|
||||
) is None:
|
||||
raise RuntimeError("VAD sensitivity entity not found")
|
||||
|
||||
vad_sensitivity = vad.VadSensitivity(vad_sensitivity_state.state)
|
||||
|
||||
return vad.VadSensitivity.to_seconds(vad_sensitivity)
|
11
homeassistant/components/assist_satellite/errors.py
Normal file
11
homeassistant/components/assist_satellite/errors.py
Normal file
@ -0,0 +1,11 @@
|
||||
"""Errors for assist satellite."""
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
|
||||
class AssistSatelliteError(HomeAssistantError):
|
||||
"""Base class for assist satellite errors."""
|
||||
|
||||
|
||||
class SatelliteBusyError(AssistSatelliteError):
|
||||
"""Satellite is busy and cannot handle the request."""
|
12
homeassistant/components/assist_satellite/icons.json
Normal file
12
homeassistant/components/assist_satellite/icons.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"entity_component": {
|
||||
"_": {
|
||||
"default": "mdi:account-voice"
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"announce": {
|
||||
"service": "mdi:bullhorn"
|
||||
}
|
||||
}
|
||||
}
|
9
homeassistant/components/assist_satellite/manifest.json
Normal file
9
homeassistant/components/assist_satellite/manifest.json
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"domain": "assist_satellite",
|
||||
"name": "Assist Satellite",
|
||||
"codeowners": ["@home-assistant/core", "@synesthesiam"],
|
||||
"dependencies": ["assist_pipeline", "stt", "tts"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||
"integration_type": "entity",
|
||||
"quality_scale": "internal"
|
||||
}
|
16
homeassistant/components/assist_satellite/services.yaml
Normal file
16
homeassistant/components/assist_satellite/services.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
announce:
|
||||
target:
|
||||
entity:
|
||||
domain: assist_satellite
|
||||
supported_features:
|
||||
- assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||
fields:
|
||||
message:
|
||||
required: false
|
||||
example: "Time to wake up!"
|
||||
selector:
|
||||
text:
|
||||
media_id:
|
||||
required: false
|
||||
selector:
|
||||
text:
|
30
homeassistant/components/assist_satellite/strings.json
Normal file
30
homeassistant/components/assist_satellite/strings.json
Normal file
@ -0,0 +1,30 @@
|
||||
{
|
||||
"title": "Assist satellite",
|
||||
"entity_component": {
|
||||
"_": {
|
||||
"name": "Assist satellite",
|
||||
"state": {
|
||||
"listening_wake_word": "Wake word",
|
||||
"listening_command": "Voice command",
|
||||
"responding": "Responding",
|
||||
"processing": "Processing"
|
||||
}
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"announce": {
|
||||
"name": "Announce",
|
||||
"description": "Let the satellite announce a message.",
|
||||
"fields": {
|
||||
"message": {
|
||||
"name": "Message",
|
||||
"description": "The message to announce."
|
||||
},
|
||||
"media_id": {
|
||||
"name": "Media ID",
|
||||
"description": "The media ID to announce instead of using text-to-speech."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
46
homeassistant/components/assist_satellite/websocket_api.py
Normal file
46
homeassistant/components/assist_satellite/websocket_api.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Assist satellite Websocket API."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import AssistSatelliteEntity
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||
"""Register the websocket API."""
|
||||
websocket_api.async_register_command(hass, websocket_intercept_wake_word)
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "assist_satellite/intercept_wake_word",
|
||||
vol.Required("entity_id"): cv.entity_domain(DOMAIN),
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.async_response
|
||||
async def websocket_intercept_wake_word(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.connection.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Intercept the next wake word from a satellite."""
|
||||
component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
|
||||
satellite = component.get_entity(msg["entity_id"])
|
||||
if satellite is None:
|
||||
connection.send_error(
|
||||
msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
|
||||
)
|
||||
return
|
||||
|
||||
wake_word_phrase = await satellite.async_intercept_wake_word()
|
||||
connection.send_result(msg["id"], {"wake_word_phrase": wake_word_phrase})
|
@ -41,6 +41,7 @@ class Platform(StrEnum):
|
||||
|
||||
AIR_QUALITY = "air_quality"
|
||||
ALARM_CONTROL_PANEL = "alarm_control_panel"
|
||||
ASSIST_SATELLITE = "assist_satellite"
|
||||
BINARY_SENSOR = "binary_sensor"
|
||||
BUTTON = "button"
|
||||
CALENDAR = "calendar"
|
||||
|
10
mypy.ini
10
mypy.ini
@ -705,6 +705,16 @@ disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.assist_satellite.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.asuswrt.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
@ -23,7 +23,7 @@ RUN --mount=from=ghcr.io/astral-sh/uv:0.2.27,source=/uv,target=/bin/uv \
|
||||
-c /usr/src/homeassistant/homeassistant/package_constraints.txt \
|
||||
-r /usr/src/homeassistant/requirements.txt \
|
||||
stdlib-list==0.10.0 pipdeptree==2.23.1 tqdm==4.66.4 ruff==0.6.2 \
|
||||
PyTurboJPEG==1.7.5 ha-ffmpeg==3.2.0 hassil==1.7.4 home-assistant-intents==2024.9.4 mutagen==1.47.0
|
||||
PyTurboJPEG==1.7.5 ha-ffmpeg==3.2.0 hassil==1.7.4 home-assistant-intents==2024.9.4 mutagen==1.47.0 pymicro-vad==1.0.1 pyspeex-noise==1.0.2
|
||||
|
||||
LABEL "name"="hassfest"
|
||||
LABEL "maintainer"="Home Assistant <hello@home-assistant.io>"
|
||||
|
3
tests/components/assist_satellite/__init__.py
Normal file
3
tests/components/assist_satellite/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""Tests for Assist Satellite."""
|
||||
|
||||
ENTITY_ID = "assist_satellite.test_entity"
|
107
tests/components/assist_satellite/conftest.py
Normal file
107
tests/components/assist_satellite/conftest.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Test helpers for Assist Satellite."""
|
||||
|
||||
import pathlib
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||
from homeassistant.components.assist_satellite import (
|
||||
DOMAIN as AS_DOMAIN,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityFeature,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
MockConfigEntry,
|
||||
MockModule,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
setup_test_component_platform,
|
||||
)
|
||||
|
||||
TEST_DOMAIN = "test"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_tts(mock_tts_cache_dir: pathlib.Path) -> None:
|
||||
"""Mock TTS cache dir fixture."""
|
||||
|
||||
|
||||
class MockAssistSatellite(AssistSatelliteEntity):
|
||||
"""Mock Assist Satellite Entity."""
|
||||
|
||||
_attr_name = "Test Entity"
|
||||
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
self.events = []
|
||||
self.announcements = []
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Handle pipeline events."""
|
||||
self.events.append(event)
|
||||
|
||||
async def async_announce(self, message: str, media_id: str) -> None:
|
||||
"""Announce media on a device."""
|
||||
self.announcements.append((message, media_id))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity() -> MockAssistSatellite:
|
||||
"""Mock Assist Satellite Entity."""
|
||||
return MockAssistSatellite()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Mock config entry."""
|
||||
entry = MockConfigEntry(domain=TEST_DOMAIN)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
) -> None:
|
||||
"""Initialize components."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
async def async_setup_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Set up test config entry."""
|
||||
await hass.config_entries.async_forward_entry_setups(config_entry, [AS_DOMAIN])
|
||||
return True
|
||||
|
||||
async def async_unload_entry_init(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> bool:
|
||||
"""Unload test config entry."""
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, AS_DOMAIN)
|
||||
return True
|
||||
|
||||
mock_integration(
|
||||
hass,
|
||||
MockModule(
|
||||
TEST_DOMAIN,
|
||||
async_setup_entry=async_setup_entry_init,
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
setup_test_component_platform(hass, AS_DOMAIN, [entity], from_config_entry=True)
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow", Mock())
|
||||
|
||||
with mock_config_flow(TEST_DOMAIN, ConfigFlow):
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return config_entry
|
332
tests/components/assist_satellite/test_entity.py
Normal file
332
tests/components/assist_satellite/test_entity.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""Test the Assist Satellite entity."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
OPTION_PREFERRED,
|
||||
AudioSettings,
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
PipelineStage,
|
||||
async_get_pipeline,
|
||||
async_update_pipeline,
|
||||
vad,
|
||||
)
|
||||
from homeassistant.components.assist_satellite import SatelliteBusyError
|
||||
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
|
||||
from . import ENTITY_ID
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
|
||||
async def test_entity_state(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test entity state represent events."""
|
||||
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
context = Context()
|
||||
audio_stream = object()
|
||||
|
||||
entity.async_set_context(context)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
|
||||
) as mock_start_pipeline:
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||
|
||||
assert mock_start_pipeline.called
|
||||
kwargs = mock_start_pipeline.call_args[1]
|
||||
assert kwargs["context"] is context
|
||||
assert kwargs["event_callback"] == entity._internal_on_pipeline_event
|
||||
assert kwargs["stt_metadata"] == stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
)
|
||||
assert kwargs["stt_stream"] is audio_stream
|
||||
assert kwargs["pipeline_id"] is None
|
||||
assert kwargs["device_id"] is None
|
||||
assert kwargs["tts_audio_output"] == "wav"
|
||||
assert kwargs["wake_word_phrase"] is None
|
||||
assert kwargs["audio_settings"] == AudioSettings(
|
||||
silence_seconds=vad.VadSensitivity.to_seconds(vad.VadSensitivity.DEFAULT)
|
||||
)
|
||||
assert kwargs["start_stage"] == PipelineStage.STT
|
||||
assert kwargs["end_stage"] == PipelineStage.TTS
|
||||
|
||||
for event_type, expected_state in (
|
||||
(PipelineEventType.RUN_START, STATE_UNKNOWN),
|
||||
(PipelineEventType.RUN_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.WAKE_WORD_START, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.WAKE_WORD_END, AssistSatelliteState.LISTENING_WAKE_WORD),
|
||||
(PipelineEventType.STT_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_VAD_START, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_VAD_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.STT_END, AssistSatelliteState.LISTENING_COMMAND),
|
||||
(PipelineEventType.INTENT_START, AssistSatelliteState.PROCESSING),
|
||||
(PipelineEventType.INTENT_END, AssistSatelliteState.PROCESSING),
|
||||
(PipelineEventType.TTS_START, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.TTS_END, AssistSatelliteState.RESPONDING),
|
||||
(PipelineEventType.ERROR, AssistSatelliteState.RESPONDING),
|
||||
):
|
||||
kwargs["event_callback"](PipelineEvent(event_type, {}))
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state.state == expected_state, event_type
|
||||
|
||||
entity.tts_response_finished()
|
||||
state = hass.states.get(ENTITY_ID)
|
||||
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("service_data", "expected_params"),
|
||||
[
|
||||
(
|
||||
{"message": "Hello"},
|
||||
("Hello", "https://www.home-assistant.io/resolved.mp3"),
|
||||
),
|
||||
(
|
||||
{
|
||||
"message": "Hello",
|
||||
"media_id": "http://example.com/bla.mp3",
|
||||
},
|
||||
("Hello", "http://example.com/bla.mp3"),
|
||||
),
|
||||
(
|
||||
{"media_id": "http://example.com/bla.mp3"},
|
||||
("", "http://example.com/bla.mp3"),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_announce(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
service_data: dict,
|
||||
expected_params: tuple[str, str],
|
||||
) -> None:
|
||||
"""Test announcing on a device."""
|
||||
await async_update_pipeline(
|
||||
hass,
|
||||
async_get_pipeline(hass),
|
||||
tts_engine="tts.mock_entity",
|
||||
tts_language="en",
|
||||
tts_voice="test-voice",
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
|
||||
return_value="media-source://bla",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
url="https://www.home-assistant.io/resolved.mp3",
|
||||
mime_type="audio/mp3",
|
||||
),
|
||||
),
|
||||
):
|
||||
await hass.services.async_call(
|
||||
"assist_satellite",
|
||||
"announce",
|
||||
service_data,
|
||||
target={"entity_id": "assist_satellite.test_entity"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
assert entity.announcements[0] == expected_params
|
||||
|
||||
|
||||
async def test_announce_busy(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
) -> None:
|
||||
"""Test that announcing while an announcement is in progress raises an error."""
|
||||
media_id = "https://www.home-assistant.io/resolved.mp3"
|
||||
announce_started = asyncio.Event()
|
||||
got_error = asyncio.Event()
|
||||
|
||||
async def async_announce(message, media_id):
|
||||
announce_started.set()
|
||||
|
||||
# Block so we can do another announcement
|
||||
await got_error.wait()
|
||||
|
||||
with patch.object(entity, "async_announce", new=async_announce):
|
||||
announce_task = asyncio.create_task(
|
||||
entity.async_internal_announce(media_id=media_id)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await announce_started.wait()
|
||||
|
||||
# Try to do a second announcement
|
||||
with pytest.raises(SatelliteBusyError):
|
||||
await entity.async_internal_announce(media_id=media_id)
|
||||
|
||||
# Avoid lingering task
|
||||
got_error.set()
|
||||
await announce_task
|
||||
|
||||
|
||||
async def test_context_refresh(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test that the context will be automatically refreshed."""
|
||||
audio_stream = object()
|
||||
|
||||
# Remove context
|
||||
entity._context = None
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream"
|
||||
):
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||
|
||||
# Context should have been refreshed
|
||||
assert entity._context is not None
|
||||
|
||||
|
||||
async def test_pipeline_entity(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test getting pipeline from an entity."""
|
||||
audio_stream = object()
|
||||
pipeline = Pipeline(
|
||||
conversation_engine="test",
|
||||
conversation_language="en",
|
||||
language="en",
|
||||
name="test-pipeline",
|
||||
stt_engine=None,
|
||||
stt_language=None,
|
||||
tts_engine=None,
|
||||
tts_language=None,
|
||||
tts_voice=None,
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
pipeline_entity_id = "select.pipeline"
|
||||
hass.states.async_set(pipeline_entity_id, pipeline.name)
|
||||
entity._attr_pipeline_entity_id = pipeline_entity_id
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, pipeline_id: str, **kwargs):
|
||||
assert pipeline_id == pipeline.id
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_get_pipelines",
|
||||
return_value=[pipeline],
|
||||
),
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_pipeline_entity_preferred(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test getting pipeline from an entity with a preferred state."""
|
||||
audio_stream = object()
|
||||
|
||||
pipeline_entity_id = "select.pipeline"
|
||||
hass.states.async_set(pipeline_entity_id, OPTION_PREFERRED)
|
||||
entity._attr_pipeline_entity_id = pipeline_entity_id
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(*args, pipeline_id: str, **kwargs):
|
||||
# Preferred pipeline
|
||||
assert pipeline_id is None
|
||||
done.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_vad_sensitivity_entity(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test getting vad sensitivity from an entity."""
|
||||
audio_stream = object()
|
||||
|
||||
vad_sensitivity_entity_id = "select.vad_sensitivity"
|
||||
hass.states.async_set(vad_sensitivity_entity_id, vad.VadSensitivity.AGGRESSIVE)
|
||||
entity._attr_vad_sensitivity_entity_id = vad_sensitivity_entity_id
|
||||
|
||||
done = asyncio.Event()
|
||||
|
||||
async def async_pipeline_from_audio_stream(
|
||||
*args, audio_settings: AudioSettings, **kwargs
|
||||
):
|
||||
# Verify vad sensitivity
|
||||
assert audio_settings.silence_seconds == vad.VadSensitivity.to_seconds(
|
||||
vad.VadSensitivity.AGGRESSIVE
|
||||
)
|
||||
done.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
):
|
||||
async with asyncio.timeout(1):
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||
await done.wait()
|
||||
|
||||
|
||||
async def test_pipeline_entity_not_found(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test that setting the pipeline entity id to a non-existent entity raises an error."""
|
||||
audio_stream = object()
|
||||
|
||||
# Set to an entity that doesn't exist
|
||||
entity._attr_pipeline_entity_id = "select.pipeline"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
||||
|
||||
|
||||
async def test_vad_sensitivity_entity_not_found(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
"""Test that setting the vad sensitivity entity id to a non-existent entity raises an error."""
|
||||
audio_stream = object()
|
||||
|
||||
# Set to an entity that doesn't exist
|
||||
entity._attr_vad_sensitivity_entity_id = "select.vad_sensitivity"
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await entity.async_accept_pipeline_from_satellite(audio_stream)
|
192
tests/components/assist_satellite/test_websocket_api.py
Normal file
192
tests/components/assist_satellite/test_websocket_api.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""Test WebSocket API."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from homeassistant.components.assist_pipeline import PipelineStage
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import ENTITY_ID
|
||||
from .conftest import MockAssistSatellite
|
||||
|
||||
from tests.common import MockUser
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
async def test_intercept_wake_word(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
start_stage=PipelineStage.STT,
|
||||
wake_word_phrase="ok, nabu",
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert response["success"]
|
||||
assert response["result"] == {"wake_word_phrase": "ok, nabu"}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_requires_on_device_wake_word(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word fails if detection happens in HA."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
# Emulate wake word processing in Home Assistant
|
||||
start_stage=PipelineStage.WAKE_WORD,
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "Only on-device wake words currently supported",
|
||||
}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_requires_wake_word_phrase(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word fails if detection happens in HA."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
for _ in range(3):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
object(),
|
||||
start_stage=PipelineStage.STT,
|
||||
# We are not passing wake word phrase
|
||||
)
|
||||
|
||||
response = await ws_client.receive_json()
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "No wake word phrase provided",
|
||||
}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_require_admin(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
hass_admin_user: MockUser,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word requires admin access."""
|
||||
# Remove admin permission and verify we're not allowed
|
||||
hass_admin_user.groups = []
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "unauthorized",
|
||||
"message": "Unauthorized",
|
||||
}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_invalid_satellite(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word requires admin access."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": "assist_satellite.invalid",
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Entity not found",
|
||||
}
|
||||
|
||||
|
||||
async def test_intercept_wake_word_twice(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
) -> None:
|
||||
"""Test intercepting a wake word requires admin access."""
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
|
||||
await ws_client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_satellite/intercept_wake_word",
|
||||
"entity_id": ENTITY_ID,
|
||||
}
|
||||
)
|
||||
response = await ws_client.receive_json()
|
||||
|
||||
assert not response["success"]
|
||||
assert response["error"] == {
|
||||
"code": "home_assistant_error",
|
||||
"message": "Wake word interception already in progress",
|
||||
}
|
Reference in New Issue
Block a user