mirror of
https://github.com/home-assistant/core.git
synced 2026-04-20 08:29:39 +02:00
Compare commits
3 Commits
timer_add_
...
fix/assist
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6113288662 | ||
|
|
e0b1e99211 | ||
|
|
a3a2557259 |
@@ -1631,6 +1631,29 @@ def _pipeline_debug_recording_thread_proc(
|
||||
wav_writer.close()
|
||||
|
||||
|
||||
async def _close_async_generators(
|
||||
*generators: AsyncIterable[Any] | None,
|
||||
) -> None:
|
||||
"""Close async generators, suppressing non-cancellation errors.
|
||||
|
||||
If ``aclose()`` on one generator is cancelled, the others are still
|
||||
attempted; the cancellation is re-raised once all generators have
|
||||
been processed.
|
||||
"""
|
||||
cancelled_exc: asyncio.CancelledError | None = None
|
||||
for gen in generators:
|
||||
aclose = getattr(gen, "aclose", None)
|
||||
if aclose is not None:
|
||||
try:
|
||||
await aclose()
|
||||
except asyncio.CancelledError as exc:
|
||||
cancelled_exc = exc
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
if cancelled_exc is not None:
|
||||
raise cancelled_exc
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class PipelineInput:
|
||||
"""Input to a pipeline run."""
|
||||
@@ -1680,12 +1703,16 @@ class PipelineInput:
|
||||
)
|
||||
current_stage: PipelineStage | None = self.run.start_stage
|
||||
|
||||
# Track async generators so they can be closed on early exit
|
||||
# (validation error, no wake word, cancellation, etc.).
|
||||
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
|
||||
stt_input_stream: AsyncIterable[EnhancedAudioChunk] | None = None
|
||||
|
||||
try:
|
||||
if validation_error is not None:
|
||||
raise validation_error
|
||||
|
||||
stt_audio_buffer: list[EnhancedAudioChunk] = []
|
||||
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
|
||||
|
||||
if self.stt_stream is not None:
|
||||
if self.run.audio_settings.needs_processor:
|
||||
@@ -1800,9 +1827,44 @@ class PipelineInput:
|
||||
)
|
||||
)
|
||||
finally:
|
||||
# Always end the run since it needs to shut down the debug recording
|
||||
# thread, etc.
|
||||
await self._cleanup(stt_input_stream, stt_processed_stream)
|
||||
|
||||
async def _cleanup(
|
||||
self,
|
||||
stt_input_stream: AsyncIterable[EnhancedAudioChunk] | None,
|
||||
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None,
|
||||
) -> None:
|
||||
"""Release pipeline resources.
|
||||
|
||||
Close the STT audio stream async generators so buffered audio
|
||||
chunks and the audio enhancer's VAD state are released promptly
|
||||
instead of waiting on garbage collection (especially slow on
|
||||
Python 3.14+). Close the wrapper first, then the upstream; skip
|
||||
if both refer to the same object to avoid double-close.
|
||||
|
||||
Catch CancelledError around each cleanup step so a cancelled
|
||||
pipeline (WebSocket unsubscribe, timeout) still runs the full
|
||||
cleanup chain — otherwise cancellation reintroduces the very
|
||||
leaks this code is trying to prevent. Re-raise at the end.
|
||||
"""
|
||||
cancelled_exc: asyncio.CancelledError | None = None
|
||||
try:
|
||||
await _close_async_generators(
|
||||
None if stt_input_stream is stt_processed_stream else stt_input_stream,
|
||||
stt_processed_stream,
|
||||
)
|
||||
except asyncio.CancelledError as exc:
|
||||
cancelled_exc = exc
|
||||
|
||||
try:
|
||||
# Always end the run since it needs to shut down the debug
|
||||
# recording thread, etc.
|
||||
await self.run.end()
|
||||
except asyncio.CancelledError as exc:
|
||||
cancelled_exc = cancelled_exc or exc
|
||||
|
||||
if cancelled_exc is not None:
|
||||
raise cancelled_exc
|
||||
|
||||
async def validate(self) -> None:
|
||||
"""Validate pipeline input against start stage."""
|
||||
|
||||
@@ -155,7 +155,10 @@ async def websocket_run(
|
||||
if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
|
||||
# Audio pipeline that will receive audio as binary websocket messages
|
||||
msg_input = msg["input"]
|
||||
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
# ~2.6s of 16kHz mono audio at 10ms chunks — enough to absorb
|
||||
# brief stalls but bounded so a stalled consumer can't grow
|
||||
# memory unboundedly.
|
||||
audio_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=256)
|
||||
incoming_sample_rate = msg_input["sample_rate"]
|
||||
wake_word_phrase: str | None = None
|
||||
|
||||
@@ -188,8 +191,11 @@ async def websocket_run(
|
||||
_connection: websocket_api.ActiveConnection,
|
||||
data: bytes,
|
||||
) -> None:
|
||||
# Forward to STT audio stream
|
||||
audio_queue.put_nowait(data)
|
||||
# Forward to STT audio stream.
|
||||
# Drop frames if the pipeline can't keep up rather than
|
||||
# growing the queue without bound.
|
||||
with contextlib.suppress(asyncio.QueueFull):
|
||||
audio_queue.put_nowait(data)
|
||||
|
||||
handler_id, unregister_handler = connection.async_register_binary_handler(
|
||||
handle_binary
|
||||
@@ -273,6 +279,20 @@ async def websocket_run(
|
||||
# Unregister binary handler
|
||||
unregister_handler()
|
||||
|
||||
# Send stop signal to unblock the stt_stream generator.
|
||||
# Empty bytes is falsy and causes the ``while chunk :=``
|
||||
# loop to exit cleanly. If the bounded queue is full,
|
||||
# discard queued audio until there is room for the stop
|
||||
# sentinel so the stream can always exit.
|
||||
while True:
|
||||
try:
|
||||
audio_queue.put_nowait(b"")
|
||||
except asyncio.QueueFull:
|
||||
with contextlib.suppress(asyncio.QueueEmpty):
|
||||
audio_queue.get_nowait()
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_admin
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Websocket tests for Voice Assistant integration."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -35,6 +36,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
||||
PipelineStorageCollection,
|
||||
PipelineStore,
|
||||
_async_local_fallback_intent_filter,
|
||||
_close_async_generators,
|
||||
async_create_default_pipeline,
|
||||
async_get_pipeline,
|
||||
async_get_pipelines,
|
||||
@@ -2153,3 +2155,138 @@ async def test_acknowledge_other_agents(
|
||||
text_to_speech.assert_not_called()
|
||||
async_converse.assert_called_once()
|
||||
get_all_targets_in_satellite_area.assert_not_called()
|
||||
|
||||
|
||||
async def test_close_async_generators_closes_generators() -> None:
|
||||
"""Test the _close_async_generators helper closes every generator."""
|
||||
closed: list[str] = []
|
||||
|
||||
async def make_gen(name: str) -> AsyncGenerator[bytes]:
|
||||
try:
|
||||
yield b""
|
||||
finally:
|
||||
closed.append(name)
|
||||
|
||||
gen_a = make_gen("a")
|
||||
gen_b = make_gen("b")
|
||||
|
||||
# Start them so there is something to close.
|
||||
await gen_a.__anext__()
|
||||
await gen_b.__anext__()
|
||||
|
||||
await _close_async_generators(gen_a, gen_b)
|
||||
|
||||
assert closed == ["a", "b"]
|
||||
|
||||
|
||||
async def test_close_async_generators_handles_none() -> None:
|
||||
"""Test the helper skips None and non-generator objects."""
|
||||
# Should not raise on None or objects without aclose.
|
||||
await _close_async_generators(None, "not a generator", None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_close_async_generators_suppresses_errors() -> None:
|
||||
"""Test the helper suppresses errors raised during aclose()."""
|
||||
|
||||
async def bad_gen() -> AsyncGenerator[bytes]:
|
||||
try:
|
||||
yield b""
|
||||
finally:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
gen = bad_gen()
|
||||
await gen.__anext__()
|
||||
|
||||
# Must not propagate the RuntimeError from the generator's finally.
|
||||
await _close_async_generators(gen)
|
||||
|
||||
|
||||
async def test_close_async_generators_closes_all_on_cancellation() -> None:
|
||||
"""Test all generators get a chance to close even on cancellation.
|
||||
|
||||
Regression guard for the leak scenario: if one generator's aclose()
|
||||
raises CancelledError, the remaining generators must still be
|
||||
closed so no audio buffers or VAD state are orphaned.
|
||||
"""
|
||||
closed: list[str] = []
|
||||
|
||||
async def cancel_gen() -> AsyncGenerator[bytes]:
|
||||
try:
|
||||
yield b""
|
||||
finally:
|
||||
closed.append("cancel")
|
||||
raise asyncio.CancelledError
|
||||
|
||||
async def normal_gen() -> AsyncGenerator[bytes]:
|
||||
try:
|
||||
yield b""
|
||||
finally:
|
||||
closed.append("normal")
|
||||
|
||||
gen_a = cancel_gen()
|
||||
gen_b = normal_gen()
|
||||
await gen_a.__anext__()
|
||||
await gen_b.__anext__()
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await _close_async_generators(gen_a, gen_b)
|
||||
|
||||
# Both generators must have been attempted, not just the first.
|
||||
assert closed == ["cancel", "normal"]
|
||||
|
||||
|
||||
async def test_pipeline_execute_closes_stt_generators(
|
||||
hass: HomeAssistant,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
) -> None:
|
||||
"""Test that PipelineInput.execute closes the STT audio generators.
|
||||
|
||||
Regression coverage for a leak where early exits of the pipeline (here:
|
||||
no wake word detected) left the upstream audio generator un-closed,
|
||||
keeping audio buffers and the audio enhancer's VAD state alive.
|
||||
"""
|
||||
closed = asyncio.Event()
|
||||
|
||||
async def audio_data() -> AsyncGenerator[bytes]:
|
||||
try:
|
||||
yield make_10ms_chunk(b"silence!")
|
||||
yield b""
|
||||
finally:
|
||||
closed.set()
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output=None,
|
||||
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Pipeline aborted (no wake word) — generator must have been closed.
|
||||
assert closed.is_set()
|
||||
|
||||
Reference in New Issue
Block a user