Clean up and incorporate feedback

This commit is contained in:
Michael Hansen
2025-08-25 14:57:40 -05:00
parent ed8ffde8c3
commit cfc056752c
6 changed files with 55 additions and 67 deletions

View File

@@ -3,11 +3,9 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
from http import HTTPStatus
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from aiohttp import web
import voluptuous as vol import voluptuous as vol
from homeassistant.components import http, stt from homeassistant.components import http, stt
@@ -16,6 +14,8 @@ from homeassistant.helpers import chat_session
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
ACKNOWLEDGE_FILENAME,
ACKNOWLEDGE_URL,
CONF_DEBUG_RECORDING_DIR, CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG, DATA_CONFIG,
DATA_LAST_WAKE_UP, DATA_LAST_WAKE_UP,
@@ -89,7 +89,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await async_setup_pipeline_store(hass) await async_setup_pipeline_store(hass)
async_register_websocket_api(hass) async_register_websocket_api(hass)
hass.http.register_view(DefaultSoundsView(hass)) await hass.http.async_register_static_paths(
[
http.StaticPathConfig(
ACKNOWLEDGE_URL,
str(Path(__file__).parent / ACKNOWLEDGE_FILENAME),
)
]
)
return True return True
@@ -138,25 +145,3 @@ async def async_pipeline_from_audio_stream(
) )
await pipeline_input.validate() await pipeline_input.validate()
await pipeline_input.execute() await pipeline_input.execute()
# -----------------------------------------------------------------------------
class DefaultSoundsView(http.HomeAssistantView):
"""HTTP view to host default sounds."""
url = f"/api/{DOMAIN}/sounds/{{filename}}"
name = f"api:{DOMAIN}:sounds"
requires_auth = False
def __init__(self, hass: HomeAssistant) -> None:
self.hass = hass
self.base_dir = Path(__file__).parent / "sounds"
async def get(self, request: web.Request, filename: str) -> web.StreamResponse:
"""Get data for sound file."""
if filename not in ("acknowledge.mp3",):
return web.Response(body="Invalid filename", status=HTTPStatus.BAD_REQUEST)
return web.FileResponse(self.base_dir / filename)

View File

@@ -23,3 +23,6 @@ SAMPLES_PER_CHUNK = SAMPLE_RATE // (1000 // MS_PER_CHUNK) # 10 ms @ 16Khz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS # 16-bit
OPTION_PREFERRED = "preferred" OPTION_PREFERRED = "preferred"
ACKNOWLEDGE_FILENAME = "acknowledge.mp3"
ACKNOWLEDGE_URL = f"/api/assist_pipeline/static/{ACKNOWLEDGE_FILENAME}"

View File

@@ -21,6 +21,7 @@ import voluptuous as vol
from homeassistant.components import ( from homeassistant.components import (
conversation, conversation,
media_player,
media_source, media_source,
stt, stt,
tts, tts,
@@ -35,7 +36,6 @@ from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
intent, intent,
network,
) )
from homeassistant.helpers.collection import ( from homeassistant.helpers.collection import (
CHANGE_UPDATED, CHANGE_UPDATED,
@@ -58,6 +58,7 @@ from homeassistant.util.limited_size_dict import LimitedSizeDict
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadSpeexEnhancer
from .const import ( from .const import (
ACKNOWLEDGE_URL,
BYTES_PER_CHUNK, BYTES_PER_CHUNK,
CONF_DEBUG_RECORDING_DIR, CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG, DATA_CONFIG,
@@ -104,8 +105,6 @@ KEY_PIPELINE_CONVERSATION_DATA: HassKey[dict[str, PipelineConversationData]] = H
# Number of response parts to handle before streaming the response # Number of response parts to handle before streaming the response
STREAM_RESPONSE_CHARS = 60 STREAM_RESPONSE_CHARS = 60
DEFAULT_ACKNOWLEDGE_MEDIA_ID = f"/api/{DOMAIN}/sounds/acknowledge.mp3"
def validate_language(data: dict[str, Any]) -> Any: def validate_language(data: dict[str, Any]) -> Any:
"""Validate language settings.""" """Validate language settings."""
@@ -456,9 +455,7 @@ class Pipeline:
wake_word_id=data["wake_word_id"], wake_word_id=data["wake_word_id"],
prefer_local_intents=data.get("prefer_local_intents", False), prefer_local_intents=data.get("prefer_local_intents", False),
acknowledge_same_area=data.get("acknowledge_same_area", True), acknowledge_same_area=data.get("acknowledge_same_area", True),
acknowledge_media_id=data.get( acknowledge_media_id=data.get("acknowledge_media_id", ACKNOWLEDGE_URL),
"acknowledge_media_id", DEFAULT_ACKNOWLEDGE_MEDIA_ID
),
) )
def to_json(self) -> dict[str, Any]: def to_json(self) -> dict[str, Any]:
@@ -1308,45 +1305,13 @@ class PipelineRun:
if tts_input_stream and self._streamed_response_text: if tts_input_stream and self._streamed_response_text:
tts_input_stream.put_nowait(None) tts_input_stream.put_nowait(None)
intent_response = conversation_result.response
device_registry = dr.async_get(self.hass)
# Check if all targeted entities were in the same area as # Check if all targeted entities were in the same area as
# the satellite device. # the satellite device.
# If so, the satellite can response with an acknowledge beep # If so, the satellite can response with an acknowledge beep
# instead of a full response. # instead of a full response.
if ( can_acknowledge = self._can_acknowledge_response(
( conversation_result.response, device_id
intent_response.response_type
== intent.IntentResponseType.ACTION_DONE
) )
and intent_response.matched_states
and device_id
and (device := device_registry.async_get(device_id))
and device.area_id
):
entity_registry = er.async_get(self.hass)
can_acknowledge = True
for state in intent_response.matched_states:
entity = entity_registry.async_get(state.entity_id)
if (
(not entity)
or (
entity.area_id
and (entity.area_id != device.area_id)
)
or (
entity.device_id
and (
entity_device := device_registry.async_get(
entity.device_id
)
)
and entity_device.area_id != device.area_id
)
):
can_acknowledge = False
break
except Exception as src_error: except Exception as src_error:
_LOGGER.exception("Unexpected error during intent recognition") _LOGGER.exception("Unexpected error during intent recognition")
@@ -1372,6 +1337,40 @@ class PipelineRun:
return (speech, can_acknowledge) return (speech, can_acknowledge)
def _can_acknowledge_response(
self, intent_response: intent.IntentResponse, device_id: str | None
) -> bool:
"""Return true if all targeted entities were in the same area as the device."""
if (
(intent_response.response_type != intent.IntentResponseType.ACTION_DONE)
or (not intent_response.matched_states)
or (not device_id)
):
return False
device_registry = dr.async_get(self.hass)
if (not (device := device_registry.async_get(device_id))) or (
not device.area_id
):
return False
entity_registry = er.async_get(self.hass)
for state in intent_response.matched_states:
entity = entity_registry.async_get(state.entity_id)
if (
(not entity)
or (entity.area_id and (entity.area_id != device.area_id))
or (
entity.device_id
and (entity_device := device_registry.async_get(entity.device_id))
and entity_device.area_id != device.area_id
)
):
return False
return True
async def prepare_text_to_speech(self) -> None: async def prepare_text_to_speech(self) -> None:
"""Prepare text-to-speech.""" """Prepare text-to-speech."""
# pipeline.tts_engine can't be None or this function is not called # pipeline.tts_engine can't be None or this function is not called
@@ -1455,7 +1454,7 @@ class PipelineRun:
media = await media_source.async_resolve_media(self.hass, media_id, None) media = await media_source.async_resolve_media(self.hass, media_id, None)
media_id = media.url media_id = media.url
else: else:
media_id = network.get_url(self.hass) + media_id media_id = media_player.async_process_play_media_url(self.hass, media_id)
tts_output = {"url": media_id} tts_output = {"url": media_id}

View File

@@ -28,6 +28,7 @@ PATHS_WITHOUT_AUTH = (
"/api/tts_proxy/", "/api/tts_proxy/",
"/api/esphome/ffmpeg_proxy/", "/api/esphome/ffmpeg_proxy/",
"/api/assist_satellite/static/", "/api/assist_satellite/static/",
"/api/assist_pipeline/static/",
) )