mirror of
https://github.com/home-assistant/core.git
synced 2025-06-25 01:21:51 +02:00
Add ask_question action to Assist satellite (#145233)
* Add get_response to Assist satellite and ESPHome * Rename get_response to ask_question * Add possible answers to questions * Add wildcard support and entity test * Add ESPHome test * Refactor to remove async_ask_question * Use single entity_id instead of target * Fix error message * Remove ESPHome test * Clean up * Revert fix
This commit is contained in:
@ -1,13 +1,23 @@
|
||||
"""Base class for assist satellite entities."""
|
||||
|
||||
from dataclasses import asdict
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from hassil.util import (
|
||||
PUNCTUATION_END,
|
||||
PUNCTUATION_END_WORD,
|
||||
PUNCTUATION_START,
|
||||
PUNCTUATION_START_WORD,
|
||||
)
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.http import StaticPathConfig
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, SupportsResponse
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
@ -23,6 +33,7 @@ from .const import (
|
||||
)
|
||||
from .entity import (
|
||||
AssistSatelliteAnnouncement,
|
||||
AssistSatelliteAnswer,
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityDescription,
|
||||
@ -34,6 +45,7 @@ from .websocket_api import async_register_websocket_api
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AssistSatelliteAnnouncement",
|
||||
"AssistSatelliteAnswer",
|
||||
"AssistSatelliteConfiguration",
|
||||
"AssistSatelliteEntity",
|
||||
"AssistSatelliteEntityDescription",
|
||||
@ -86,6 +98,62 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"async_internal_start_conversation",
|
||||
[AssistSatelliteEntityFeature.START_CONVERSATION],
|
||||
)
|
||||
|
||||
async def handle_ask_question(call: ServiceCall) -> dict[str, Any]:
|
||||
"""Handle a Show View service call."""
|
||||
satellite_entity_id: str = call.data[ATTR_ENTITY_ID]
|
||||
satellite_entity: AssistSatelliteEntity | None = component.get_entity(
|
||||
satellite_entity_id
|
||||
)
|
||||
if satellite_entity is None:
|
||||
raise HomeAssistantError(
|
||||
f"Invalid Assist satellite entity id: {satellite_entity_id}"
|
||||
)
|
||||
|
||||
ask_question_args = {
|
||||
"question": call.data.get("question"),
|
||||
"question_media_id": call.data.get("question_media_id"),
|
||||
"preannounce": call.data.get("preannounce", False),
|
||||
"answers": call.data.get("answers"),
|
||||
}
|
||||
|
||||
if preannounce_media_id := call.data.get("preannounce_media_id"):
|
||||
ask_question_args["preannounce_media_id"] = preannounce_media_id
|
||||
|
||||
answer = await satellite_entity.async_internal_ask_question(**ask_question_args)
|
||||
|
||||
if answer is None:
|
||||
raise HomeAssistantError("No answer from satellite")
|
||||
|
||||
return asdict(answer)
|
||||
|
||||
hass.services.async_register(
|
||||
domain=DOMAIN,
|
||||
service="ask_question",
|
||||
service_func=handle_ask_question,
|
||||
schema=vol.All(
|
||||
{
|
||||
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
|
||||
vol.Optional("question"): str,
|
||||
vol.Optional("question_media_id"): str,
|
||||
vol.Optional("preannounce"): bool,
|
||||
vol.Optional("preannounce_media_id"): str,
|
||||
vol.Optional("answers"): [
|
||||
{
|
||||
vol.Required("id"): str,
|
||||
vol.Required("sentences"): vol.All(
|
||||
cv.ensure_list,
|
||||
[cv.string],
|
||||
has_one_non_empty_item,
|
||||
has_no_punctuation,
|
||||
),
|
||||
}
|
||||
],
|
||||
},
|
||||
cv.has_at_least_one_key("question", "question_media_id"),
|
||||
),
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
)
|
||||
hass.data[CONNECTION_TEST_DATA] = {}
|
||||
async_register_websocket_api(hass)
|
||||
hass.http.register_view(ConnectionTestView())
|
||||
@ -110,3 +178,29 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
|
||||
|
||||
|
||||
def has_no_punctuation(value: list[str]) -> list[str]:
|
||||
"""Validate result does not contain punctuation."""
|
||||
for sentence in value:
|
||||
if (
|
||||
PUNCTUATION_START.search(sentence)
|
||||
or PUNCTUATION_END.search(sentence)
|
||||
or PUNCTUATION_START_WORD.search(sentence)
|
||||
or PUNCTUATION_END_WORD.search(sentence)
|
||||
):
|
||||
raise vol.Invalid("sentence should not contain punctuation")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def has_one_non_empty_item(value: list[str]) -> list[str]:
|
||||
"""Validate result has at least one item."""
|
||||
if len(value) < 1:
|
||||
raise vol.Invalid("at least one sentence is required")
|
||||
|
||||
for sentence in value:
|
||||
if not sentence:
|
||||
raise vol.Invalid("sentences cannot be empty")
|
||||
|
||||
return value
|
||||
|
@ -4,12 +4,16 @@ from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Literal, final
|
||||
|
||||
from hassil import Intents, recognize
|
||||
from hassil.expression import Expression, ListReference, Sequence
|
||||
from hassil.intents import WildcardSlotList
|
||||
|
||||
from homeassistant.components import conversation, media_source, stt, tts
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
OPTION_PREFERRED,
|
||||
@ -105,6 +109,20 @@ class AssistSatelliteAnnouncement:
|
||||
"""Media ID to be played before announcement."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistSatelliteAnswer:
|
||||
"""Answer to a question."""
|
||||
|
||||
id: str | None
|
||||
"""Matched answer id or None if no answer was matched."""
|
||||
|
||||
sentence: str
|
||||
"""Raw sentence text from user response."""
|
||||
|
||||
slots: dict[str, Any] = field(default_factory=dict)
|
||||
"""Matched slots from answer."""
|
||||
|
||||
|
||||
class AssistSatelliteEntity(entity.Entity):
|
||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||
|
||||
@ -120,8 +138,10 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
_is_announcing = False
|
||||
_extra_system_prompt: str | None = None
|
||||
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
|
||||
_stt_intercept_future: asyncio.Future[str | None] | None = None
|
||||
_attr_tts_options: dict[str, Any] | None = None
|
||||
_pipeline_task: asyncio.Task | None = None
|
||||
_ask_question_future: asyncio.Future[str | None] | None = None
|
||||
|
||||
__assist_satellite_state = AssistSatelliteState.IDLE
|
||||
|
||||
@ -309,6 +329,112 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
"""Start a conversation from the satellite."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_internal_ask_question(
|
||||
self,
|
||||
question: str | None = None,
|
||||
question_media_id: str | None = None,
|
||||
preannounce: bool = True,
|
||||
preannounce_media_id: str = PREANNOUNCE_URL,
|
||||
answers: list[dict[str, Any]] | None = None,
|
||||
) -> AssistSatelliteAnswer | None:
|
||||
"""Ask a question and get a user's response from the satellite.
|
||||
|
||||
If question_media_id is not provided, question is synthesized to audio
|
||||
with the selected pipeline.
|
||||
|
||||
If question_media_id is provided, it is played directly. It is possible
|
||||
to omit the message and the satellite will not show any text.
|
||||
|
||||
If preannounce is True, a sound is played before the start message or media.
|
||||
If preannounce_media_id is provided, it overrides the default sound.
|
||||
|
||||
Calls async_start_conversation.
|
||||
"""
|
||||
await self._cancel_running_pipeline()
|
||||
|
||||
if question is None:
|
||||
question = ""
|
||||
|
||||
announcement = await self._resolve_announcement_media_id(
|
||||
question,
|
||||
question_media_id,
|
||||
preannounce_media_id=preannounce_media_id if preannounce else None,
|
||||
)
|
||||
|
||||
if self._is_announcing:
|
||||
raise SatelliteBusyError
|
||||
|
||||
self._is_announcing = True
|
||||
self._set_state(AssistSatelliteState.RESPONDING)
|
||||
self._ask_question_future = asyncio.Future()
|
||||
|
||||
try:
|
||||
# Wait for announcement to finish
|
||||
await self.async_start_conversation(announcement)
|
||||
|
||||
# Wait for response text
|
||||
response_text = await self._ask_question_future
|
||||
if response_text is None:
|
||||
raise HomeAssistantError("No answer from question")
|
||||
|
||||
if not answers:
|
||||
return AssistSatelliteAnswer(id=None, sentence=response_text)
|
||||
|
||||
return self._question_response_to_answer(response_text, answers)
|
||||
finally:
|
||||
self._is_announcing = False
|
||||
self._set_state(AssistSatelliteState.IDLE)
|
||||
self._ask_question_future = None
|
||||
|
||||
def _question_response_to_answer(
|
||||
self, response_text: str, answers: list[dict[str, Any]]
|
||||
) -> AssistSatelliteAnswer:
|
||||
"""Match text to a pre-defined set of answers."""
|
||||
|
||||
# Build intents and match
|
||||
intents = Intents.from_dict(
|
||||
{
|
||||
"language": self.hass.config.language,
|
||||
"intents": {
|
||||
"QuestionIntent": {
|
||||
"data": [
|
||||
{
|
||||
"sentences": answer["sentences"],
|
||||
"metadata": {"answer_id": answer["id"]},
|
||||
}
|
||||
for answer in answers
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Assume slot list references are wildcards
|
||||
wildcard_names: set[str] = set()
|
||||
for intent in intents.intents.values():
|
||||
for intent_data in intent.data:
|
||||
for sentence in intent_data.sentences:
|
||||
_collect_list_references(sentence, wildcard_names)
|
||||
|
||||
for wildcard_name in wildcard_names:
|
||||
intents.slot_lists[wildcard_name] = WildcardSlotList(wildcard_name)
|
||||
|
||||
# Match response text
|
||||
result = recognize(response_text, intents)
|
||||
if result is None:
|
||||
# No match
|
||||
return AssistSatelliteAnswer(id=None, sentence=response_text)
|
||||
|
||||
assert result.intent_metadata
|
||||
return AssistSatelliteAnswer(
|
||||
id=result.intent_metadata["answer_id"],
|
||||
sentence=response_text,
|
||||
slots={
|
||||
entity_name: entity.value
|
||||
for entity_name, entity in result.entities.items()
|
||||
},
|
||||
)
|
||||
|
||||
async def async_accept_pipeline_from_satellite(
|
||||
self,
|
||||
audio_stream: AsyncIterable[bytes],
|
||||
@ -351,6 +477,11 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
self._internal_on_pipeline_event(PipelineEvent(PipelineEventType.RUN_END))
|
||||
return
|
||||
|
||||
if (self._ask_question_future is not None) and (
|
||||
start_stage == PipelineStage.STT
|
||||
):
|
||||
end_stage = PipelineStage.STT
|
||||
|
||||
device_id = self.registry_entry.device_id if self.registry_entry else None
|
||||
|
||||
# Refresh context if necessary
|
||||
@ -433,6 +564,16 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
self._set_state(AssistSatelliteState.IDLE)
|
||||
elif event.type is PipelineEventType.STT_START:
|
||||
self._set_state(AssistSatelliteState.LISTENING)
|
||||
elif event.type is PipelineEventType.STT_END:
|
||||
# Intercepting text for ask question
|
||||
if (
|
||||
(self._ask_question_future is not None)
|
||||
and (not self._ask_question_future.done())
|
||||
and event.data
|
||||
):
|
||||
self._ask_question_future.set_result(
|
||||
event.data.get("stt_output", {}).get("text")
|
||||
)
|
||||
elif event.type is PipelineEventType.INTENT_START:
|
||||
self._set_state(AssistSatelliteState.PROCESSING)
|
||||
elif event.type is PipelineEventType.TTS_START:
|
||||
@ -443,6 +584,12 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
if not self._run_has_tts:
|
||||
self._set_state(AssistSatelliteState.IDLE)
|
||||
|
||||
if (self._ask_question_future is not None) and (
|
||||
not self._ask_question_future.done()
|
||||
):
|
||||
# No text for ask question
|
||||
self._ask_question_future.set_result(None)
|
||||
|
||||
self.on_pipeline_event(event)
|
||||
|
||||
@callback
|
||||
@ -577,3 +724,15 @@ class AssistSatelliteEntity(entity.Entity):
|
||||
media_id_source=media_id_source,
|
||||
preannounce_media_id=preannounce_media_id,
|
||||
)
|
||||
|
||||
|
||||
def _collect_list_references(expression: Expression, list_names: set[str]) -> None:
|
||||
"""Collect list reference names recursively."""
|
||||
if isinstance(expression, Sequence):
|
||||
seq: Sequence = expression
|
||||
for item in seq.items:
|
||||
_collect_list_references(item, list_names)
|
||||
elif isinstance(expression, ListReference):
|
||||
# {list}
|
||||
list_ref: ListReference = expression
|
||||
list_names.add(list_ref.slot_name)
|
||||
|
@ -10,6 +10,9 @@
|
||||
},
|
||||
"start_conversation": {
|
||||
"service": "mdi:forum"
|
||||
},
|
||||
"ask_question": {
|
||||
"service": "mdi:microphone-question"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,5 +5,6 @@
|
||||
"dependencies": ["assist_pipeline", "http", "stt", "tts"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/assist_satellite",
|
||||
"integration_type": "entity",
|
||||
"quality_scale": "internal"
|
||||
"quality_scale": "internal",
|
||||
"requirements": ["hassil==2.2.3"]
|
||||
}
|
||||
|
@ -54,3 +54,35 @@ start_conversation:
|
||||
required: false
|
||||
selector:
|
||||
text:
|
||||
ask_question:
|
||||
fields:
|
||||
entity_id:
|
||||
required: true
|
||||
selector:
|
||||
entity:
|
||||
domain: assist_satellite
|
||||
supported_features:
|
||||
- assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||
question:
|
||||
required: false
|
||||
example: "What kind of music would you like to play?"
|
||||
default: ""
|
||||
selector:
|
||||
text:
|
||||
question_media_id:
|
||||
required: false
|
||||
selector:
|
||||
text:
|
||||
preannounce:
|
||||
required: false
|
||||
default: true
|
||||
selector:
|
||||
boolean:
|
||||
preannounce_media_id:
|
||||
required: false
|
||||
selector:
|
||||
text:
|
||||
answers:
|
||||
required: false
|
||||
selector:
|
||||
object:
|
||||
|
@ -59,6 +59,36 @@
|
||||
"description": "Custom media ID to play before the start message or media."
|
||||
}
|
||||
}
|
||||
},
|
||||
"ask_question": {
|
||||
"name": "Ask question",
|
||||
"description": "Asks a question and gets the user's response.",
|
||||
"fields": {
|
||||
"entity_id": {
|
||||
"name": "Entity",
|
||||
"description": "Assist satellite entity to ask the question on."
|
||||
},
|
||||
"question": {
|
||||
"name": "Question",
|
||||
"description": "The question to ask."
|
||||
},
|
||||
"question_media_id": {
|
||||
"name": "Question media ID",
|
||||
"description": "The media ID of the question to use instead of text-to-speech."
|
||||
},
|
||||
"preannounce": {
|
||||
"name": "Preannounce",
|
||||
"description": "Play a sound before the start message or media."
|
||||
},
|
||||
"preannounce_media_id": {
|
||||
"name": "Preannounce media ID",
|
||||
"description": "Custom media ID to play before the start message or media."
|
||||
},
|
||||
"answers": {
|
||||
"name": "Answers",
|
||||
"description": "Possible answers to the question."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
1
requirements_all.txt
generated
1
requirements_all.txt
generated
@ -1129,6 +1129,7 @@ hass-nabucasa==0.102.0
|
||||
# homeassistant.components.splunk
|
||||
hass-splunk==0.1.1
|
||||
|
||||
# homeassistant.components.assist_satellite
|
||||
# homeassistant.components.conversation
|
||||
hassil==2.2.3
|
||||
|
||||
|
1
requirements_test_all.txt
generated
1
requirements_test_all.txt
generated
@ -984,6 +984,7 @@ habluetooth==3.49.0
|
||||
# homeassistant.components.cloud
|
||||
hass-nabucasa==0.102.0
|
||||
|
||||
# homeassistant.components.assist_satellite
|
||||
# homeassistant.components.conversation
|
||||
hassil==2.2.3
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Generator
|
||||
from dataclasses import asdict
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -20,6 +21,7 @@ from homeassistant.components.assist_pipeline import (
|
||||
)
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteAnnouncement,
|
||||
AssistSatelliteAnswer,
|
||||
SatelliteBusyError,
|
||||
)
|
||||
from homeassistant.components.assist_satellite.const import PREANNOUNCE_URL
|
||||
@ -708,6 +710,127 @@ async def test_start_conversation_default_preannounce(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("service_data", "response_text", "expected_answer"),
|
||||
[
|
||||
(
|
||||
{"preannounce": False},
|
||||
"jazz",
|
||||
AssistSatelliteAnswer(id=None, sentence="jazz"),
|
||||
),
|
||||
(
|
||||
{
|
||||
"answers": [
|
||||
{"id": "jazz", "sentences": ["[some] jazz [please]"]},
|
||||
{"id": "rock", "sentences": ["[some] rock [please]"]},
|
||||
],
|
||||
"preannounce": False,
|
||||
},
|
||||
"Some Rock, please.",
|
||||
AssistSatelliteAnswer(id="rock", sentence="Some Rock, please."),
|
||||
),
|
||||
(
|
||||
{
|
||||
"answers": [
|
||||
{
|
||||
"id": "genre",
|
||||
"sentences": ["genre {genre} [please]"],
|
||||
},
|
||||
{
|
||||
"id": "artist",
|
||||
"sentences": ["artist {artist} [please]"],
|
||||
},
|
||||
],
|
||||
"preannounce": False,
|
||||
},
|
||||
"artist Pink Floyd",
|
||||
AssistSatelliteAnswer(
|
||||
id="artist",
|
||||
sentence="artist Pink Floyd",
|
||||
slots={"artist": "Pink Floyd"},
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_ask_question(
|
||||
hass: HomeAssistant,
|
||||
init_components: ConfigEntry,
|
||||
entity: MockAssistSatellite,
|
||||
service_data: dict,
|
||||
response_text: str,
|
||||
expected_answer: AssistSatelliteAnswer,
|
||||
) -> None:
|
||||
"""Test asking a question on a device and matching an answer."""
|
||||
entity_id = "assist_satellite.test_entity"
|
||||
question_text = "What kind of music would you like to listen to?"
|
||||
|
||||
await async_update_pipeline(
|
||||
hass, async_get_pipeline(hass), stt_engine="test-stt-engine", stt_language="en"
|
||||
)
|
||||
|
||||
async def speech_to_text(self, *args, **kwargs):
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.STT_END, {"stt_output": {"text": response_text}}
|
||||
)
|
||||
)
|
||||
|
||||
return response_text
|
||||
|
||||
original_start_conversation = entity.async_start_conversation
|
||||
|
||||
async def async_start_conversation(start_announcement):
|
||||
# Verify state change
|
||||
assert entity.state == AssistSatelliteState.RESPONDING
|
||||
await original_start_conversation(start_announcement)
|
||||
|
||||
audio_stream = object()
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.prepare_speech_to_text"
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.PipelineRun.speech_to_text",
|
||||
speech_to_text,
|
||||
),
|
||||
):
|
||||
await entity.async_accept_pipeline_from_satellite(
|
||||
audio_stream, start_stage=PipelineStage.STT
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.tts.generate_media_source_id",
|
||||
return_value="media-source://generated",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_resolve_engine",
|
||||
return_value="tts.cloud",
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_create_stream",
|
||||
return_value=MockResultStream(hass, "wav", b""),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
return_value=PlayMedia(
|
||||
url="https://www.home-assistant.io/resolved.mp3",
|
||||
mime_type="audio/mp3",
|
||||
),
|
||||
),
|
||||
patch.object(entity, "async_start_conversation", new=async_start_conversation),
|
||||
):
|
||||
response = await hass.services.async_call(
|
||||
"assist_satellite",
|
||||
"ask_question",
|
||||
{"entity_id": entity_id, "question": question_text, **service_data},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
assert entity.state == AssistSatelliteState.IDLE
|
||||
assert response == asdict(expected_answer)
|
||||
|
||||
|
||||
async def test_wake_word_start_keeps_responding(
|
||||
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
|
||||
) -> None:
|
||||
|
Reference in New Issue
Block a user