Add ask_question action to Assist satellite (#145233)

* Add get_response to Assist satellite and ESPHome

* Rename get_response to ask_question

* Add possible answers to questions

* Add wildcard support and entity test

* Add ESPHome test

* Refactor to remove async_ask_question

* Use single entity_id instead of target

* Fix error message

* Remove ESPHome test

* Clean up

* Revert fix
This commit is contained in:
Michael Hansen
2025-06-19 16:50:14 -05:00
committed by GitHub
parent 2c13c70e12
commit 341d9f15f0
9 changed files with 447 additions and 3 deletions

View File

@ -1,13 +1,23 @@
"""Base class for assist satellite entities."""
from dataclasses import asdict
import logging
from pathlib import Path
from typing import Any
from hassil.util import (
PUNCTUATION_END,
PUNCTUATION_END_WORD,
PUNCTUATION_START,
PUNCTUATION_START_WORD,
)
import voluptuous as vol
from homeassistant.components.http import StaticPathConfig
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.typing import ConfigType
@ -23,6 +33,7 @@ from .const import (
)
from .entity import (
AssistSatelliteAnnouncement,
AssistSatelliteAnswer,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
@ -34,6 +45,7 @@ from .websocket_api import async_register_websocket_api
__all__ = [
"DOMAIN",
"AssistSatelliteAnnouncement",
"AssistSatelliteAnswer",
"AssistSatelliteConfiguration",
"AssistSatelliteEntity",
"AssistSatelliteEntityDescription",
@ -86,6 +98,62 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"async_internal_start_conversation",
[AssistSatelliteEntityFeature.START_CONVERSATION],
)
async def handle_ask_question(call: ServiceCall) -> dict[str, Any]:
"""Handle a Show View service call."""
satellite_entity_id: str = call.data[ATTR_ENTITY_ID]
satellite_entity: AssistSatelliteEntity | None = component.get_entity(
satellite_entity_id
)
if satellite_entity is None:
raise HomeAssistantError(
f"Invalid Assist satellite entity id: {satellite_entity_id}"
)
ask_question_args = {
"question": call.data.get("question"),
"question_media_id": call.data.get("question_media_id"),
"preannounce": call.data.get("preannounce", False),
"answers": call.data.get("answers"),
}
if preannounce_media_id := call.data.get("preannounce_media_id"):
ask_question_args["preannounce_media_id"] = preannounce_media_id
answer = await satellite_entity.async_internal_ask_question(**ask_question_args)
if answer is None:
raise HomeAssistantError("No answer from satellite")
return asdict(answer)
hass.services.async_register(
domain=DOMAIN,
service="ask_question",
service_func=handle_ask_question,
schema=vol.All(
{
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
vol.Optional("question"): str,
vol.Optional("question_media_id"): str,
vol.Optional("preannounce"): bool,
vol.Optional("preannounce_media_id"): str,
vol.Optional("answers"): [
{
vol.Required("id"): str,
vol.Required("sentences"): vol.All(
cv.ensure_list,
[cv.string],
has_one_non_empty_item,
has_no_punctuation,
),
}
],
},
cv.has_at_least_one_key("question", "question_media_id"),
),
supports_response=SupportsResponse.ONLY,
)
hass.data[CONNECTION_TEST_DATA] = {}
async_register_websocket_api(hass)
hass.http.register_view(ConnectionTestView())
@ -110,3 +178,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
def has_no_punctuation(value: list[str]) -> list[str]:
"""Validate result does not contain punctuation."""
for sentence in value:
if (
PUNCTUATION_START.search(sentence)
or PUNCTUATION_END.search(sentence)
or PUNCTUATION_START_WORD.search(sentence)
or PUNCTUATION_END_WORD.search(sentence)
):
raise vol.Invalid("sentence should not contain punctuation")
return value
def has_one_non_empty_item(value: list[str]) -> list[str]:
"""Validate result has at least one item."""
if len(value) < 1:
raise vol.Invalid("at least one sentence is required")
for sentence in value:
if not sentence:
raise vol.Invalid("sentences cannot be empty")
return value

View File

@ -4,12 +4,16 @@ from abc import abstractmethod
import asyncio
from collections.abc import AsyncIterable
import contextlib
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import StrEnum
import logging
import time
from typing import Any, Literal, final
from hassil import Intents, recognize
from hassil.expression import Expression, ListReference, Sequence
from hassil.intents import WildcardSlotList
from homeassistant.components import conversation, media_source, stt, tts
from homeassistant.components.assist_pipeline import (
OPTION_PREFERRED,
@ -105,6 +109,20 @@ class AssistSatelliteAnnouncement:
"""Media ID to be played before announcement."""
@dataclass
class AssistSatelliteAnswer:
"""Answer to a question."""
id: str | None
"""Matched answer id or None if no answer was matched."""
sentence: str
"""Raw sentence text from user response."""
slots: dict[str, Any] = field(default_factory=dict)
"""Matched slots from answer."""
class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite."""
@ -120,8 +138,10 @@ class AssistSatelliteEntity(entity.Entity):
_is_announcing = False
_extra_system_prompt: str | None = None
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
_stt_intercept_future: asyncio.Future[str | None] | None = None
_attr_tts_options: dict[str, Any] | None = None
_pipeline_task: asyncio.Task | None = None
_ask_question_future: asyncio.Future[str | None] | None = None
__assist_satellite_state = AssistSatelliteState.IDLE
@ -309,6 +329,112 @@ class AssistSatelliteEntity(entity.Entity):
"""Start a conversation from the satellite."""
raise NotImplementedError
async def async_internal_ask_question(
self,
question: str | None = None,
question_media_id: str | None = None,
preannounce: bool = True,
preannounce_media_id: str = PREANNOUNCE_URL,
answers: list[dict[str, Any]] | None = None,
) -> AssistSatelliteAnswer | None:
"""Ask a question and get a user's response from the satellite.
If question_media_id is not provided, question is synthesized to audio
with the selected pipeline.
If question_media_id is provided, it is played directly. It is possible
to omit the message and the satellite will not show any text.
If preannounce is True, a sound is played before the start message or media.
If preannounce_media_id is provided, it overrides the default sound.
Calls async_start_conversation.
"""
await self._cancel_running_pipeline()
if question is None:
question = ""
announcement = await self._resolve_announcement_media_id(
question,
question_media_id,
preannounce_media_id=preannounce_media_id if preannounce else None,
)
if self._is_announcing:
raise SatelliteBusyError
self._is_announcing = True
self._set_state(AssistSatelliteState.RESPONDING)
self._ask_question_future = asyncio.Future()
try:
# Wait for announcement to finish
await self.async_start_conversation(announcement)
# Wait for response text
response_text = await self._ask_question_future
if response_text is None:
raise HomeAssistantError("No answer from question")
if not answers:
return AssistSatelliteAnswer(id=None, sentence=response_text)
return self._question_response_to_answer(response_text, answers)
finally:
self._is_announcing = False
self._set_state(AssistSatelliteState.IDLE)
self._ask_question_future = None
def _question_response_to_answer(
self, response_text: str, answers: list[dict[str, Any]]
) -> AssistSatelliteAnswer:
"""Match text to a pre-defined set of answers."""
# Build intents and match
intents = Intents.from_dict(
{
"language": self.hass.config.language,
"intents": {
"QuestionIntent": {
"data": [
{
"sentences": answer["sentences"],
"metadata": {"answer_id": answer["id"]},
}
for answer in answers
]
}
},
}
)
# Assume slot list references are wildcards
wildcard_names: set[str] = set()
for intent in intents.intents.values():
for intent_data in intent.data:
for sentence in intent_data.sentences:
_collect_list_references(sentence, wildcard_names)
for wildcard_name in wildcard_names:
intents.slot_lists[wildcard_name] = WildcardSlotList(wildcard_name)
# Match response text
result = recognize(response_text, intents)
if result is None:
# No match
return AssistSatelliteAnswer(id=None, sentence=response_text)
assert result.intent_metadata
return AssistSatelliteAnswer(
id=result.intent_metadata["answer_id"],
sentence=response_text,
slots={
entity_name: entity.value
for entity_name, entity in result.entities.items()
},
)
async def async_accept_pipeline_from_satellite(
self,
audio_stream: AsyncIterable[bytes],
@ -351,6 +477,11 @@ class AssistSatelliteEntity(entity.Entity):
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
return
if (self._ask_question_future is not None) and (
start_stage == PipelineStage.STT
):
end_stage = PipelineStage.STT
device_id = self.registry_entry.device_id if self.registry_entry else None
# Refresh context if necessary
@ -433,6 +564,16 @@ class AssistSatelliteEntity(entity.Entity):
self._set_state(AssistSatelliteState.IDLE)
elif event.type is PipelineEventType.STT_START:
self._set_state(AssistSatelliteState.LISTENING)
elif event.type is PipelineEventType.STT_END:
# Intercepting text for ask question
if (
(self._ask_question_future is not None)
and (not self._ask_question_future.done())
and event.data
):
self._ask_question_future.set_result(
event.data.get("stt_output", {}).get("text")
)
elif event.type is PipelineEventType.INTENT_START:
self._set_state(AssistSatelliteState.PROCESSING)
elif event.type is PipelineEventType.TTS_START:
@ -443,6 +584,12 @@ class AssistSatelliteEntity(entity.Entity):
if not self._run_has_tts:
self._set_state(AssistSatelliteState.IDLE)
if (self._ask_question_future is not None) and (
not self._ask_question_future.done()
):
# No text for ask question
self._ask_question_future.set_result(None)
self.on_pipeline_event(event)
@callback
@ -577,3 +724,15 @@ class AssistSatelliteEntity(entity.Entity):
media_id_source=media_id_source,
preannounce_media_id=preannounce_media_id,
)
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
"""Collect list reference names recursively."""
if isinstance(expression, Sequence):
seq: Sequence = expression
for item in seq.items:
_collect_list_references(item, list_names)
elif isinstance(expression, ListReference):
# {list}
list_ref: ListReference = expression
list_names.add(list_ref.slot_name)

View File

@ -10,6 +10,9 @@
},
"start_conversation": {
"service": "mdi:forum"
},
"ask_question": {
"service": "mdi:microphone-question"
}
}
}

View File

@ -5,5 +5,6 @@
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
"integration_type": "entity",
"quality_scale": "internal"
"quality_scale": "internal",
"requirements": ["hassil==2.2.3"]
}

View File

@ -54,3 +54,35 @@ start_conversation:
required: false
selector:
text:
ask_question:
fields:
entity_id:
required: true
selector:
entity:
domain: assist_satellite
supported_features:
- assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
question:
required: false
example: "What kind of music would you like to play?"
default: ""
selector:
text:
question_media_id:
required: false
selector:
text:
preannounce:
required: false
default: true
selector:
boolean:
preannounce_media_id:
required: false
selector:
text:
answers:
required: false
selector:
object:

View File

@ -59,6 +59,36 @@
"description": "Custom media ID to play before the start message or media."
}
}
},
"ask_question": {
"name": "Ask question",
"description": "Asks a question and gets the user's response.",
"fields": {
"entity_id": {
"name": "Entity",
"description": "Assist satellite entity to ask the question on."
},
"question": {
"name": "Question",
"description": "The question to ask."
},
"question_media_id": {
"name": "Question media ID",
"description": "The media ID of the question to use instead of text-to-speech."
},
"preannounce": {
"name": "Preannounce",
"description": "Play a sound before the start message or media."
},
"preannounce_media_id": {
"name": "Preannounce media ID",
"description": "Custom media ID to play before the start message or media."
},
"answers": {
"name": "Answers",
"description": "Possible answers to the question."
}
}
}
}
}

1
requirements_all.txt generated
View File

@ -1129,6 +1129,7 @@ hass-nabucasa==0.102.0
# homeassistant.components.splunk
hass-splunk==0.1.1
# homeassistant.components.assist_satellite
# homeassistant.components.conversation
hassil==2.2.3

View File

@ -984,6 +984,7 @@ habluetooth==3.49.0
# homeassistant.components.cloud
hass-nabucasa==0.102.0
# homeassistant.components.assist_satellite
# homeassistant.components.conversation
hassil==2.2.3

View File

@ -2,6 +2,7 @@
import asyncio
from collections.abc import Generator
from dataclasses import asdict
from unittest.mock import Mock, patch
import pytest
@ -20,6 +21,7 @@ from homeassistant.components.assist_pipeline import (
)
from homeassistant.components.assist_satellite import (
AssistSatelliteAnnouncement,
AssistSatelliteAnswer,
SatelliteBusyError,
)
from homeassistant.components.assist_satellite.const import PREANNOUNCE_URL
@ -708,6 +710,127 @@ async def test_start_conversation_default_preannounce(
)
@pytest.mark.parametrize(
("service_data", "response_text", "expected_answer"),
[
(
{"preannounce": False},
"jazz",
AssistSatelliteAnswer(id=None, sentence="jazz"),
),
(
{
"answers": [
{"id": "jazz", "sentences": ["[some] jazz [please]"]},
{"id": "rock", "sentences": ["[some] rock [please]"]},
],
"preannounce": False,
},
"Some Rock, please.",
AssistSatelliteAnswer(id="rock", sentence="Some Rock, please."),
),
(
{
"answers": [
{
"id": "genre",
"sentences": ["genre {genre} [please]"],
},
{
"id": "artist",
"sentences": ["artist {artist} [please]"],
},
],
"preannounce": False,
},
"artist Pink Floyd",
AssistSatelliteAnswer(
id="artist",
sentence="artist Pink Floyd",
slots={"artist": "Pink Floyd"},
),
),
],
)
async def test_ask_question(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
service_data: dict,
response_text: str,
expected_answer: AssistSatelliteAnswer,
) -> None:
"""Test asking a question on a device and matching an answer."""
entity_id = "assist_satellite.test_entity"
question_text = "What kind of music would you like to listen to?"
await async_update_pipeline(
hass, async_get_pipeline(hass), stt_engine="test-stt-engine", stt_language="en"
)
async def speech_to_text(self, *args, **kwargs):
self.process_event(
PipelineEvent(
PipelineEventType.STT_END, {"stt_output": {"text": response_text}}
)
)
return response_text
original_start_conversation = entity.async_start_conversation
async def async_start_conversation(start_announcement):
# Verify state change
assert entity.state == AssistSatelliteState.RESPONDING
await original_start_conversation(start_announcement)
audio_stream = object()
with (
patch(
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text"
),
patch(
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.speech_to_text",
speech_to_text,
),
):
await entity.async_accept_pipeline_from_satellite(
audio_stream, start_stage=PipelineStage.STT
)
with (
patch(
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://generated",
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=PlayMedia(
url="https://www.home-assistant.io/resolved.mp3",
mime_type="audio/mp3",
),
),
patch.object(entity, "async_start_conversation", new=async_start_conversation),
):
response = await hass.services.async_call(
"assist_satellite",
"ask_question",
{"entity_id": entity_id, "question": question_text, **service_data},
blocking=True,
return_response=True,
)
assert entity.state == AssistSatelliteState.IDLE
assert response == asdict(expected_answer)
async def test_wake_word_start_keeps_responding(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None: