Compare commits

...

3 Commits

Author SHA1 Message Date
Marcel van der Veldt
6113288662 Address Copilot review: robust cleanup on cancellation
Ensure pipeline cleanup runs to completion even when execute() is
cancelled mid-run (WebSocket unsubscribe, timeout). If cancellation
hits one cleanup step, the remaining steps still run and the
CancelledError is re-raised at the end.

Also correct the queue size comment to match the actual 10ms chunks
used by the pipeline (256 * 10ms = ~2.6s, not ~4s).
2026-04-15 14:13:28 +02:00
Marcel van der Veldt
e0b1e99211 Add coverage for _close_async_generators and document queue bound 2026-04-15 12:53:55 +02:00
Marcel van der Veldt
a3a2557259 Clean up async generator lifecycle in assist pipeline
Close the STT audio stream async generators (process_enhance_audio /
process_volume_only / buffer_then_audio_stream) in PipelineInput.execute()
finally block so buffered audio chunks and the audio enhancer VAD state
are released promptly on early exit instead of waiting on garbage
collection.

Also bound the WebSocket audio queue to 256 frames (~4s of 16kHz mono)
and drop frames on overflow rather than growing without limit. Send a
stop sentinel on disconnect to unblock the stt_stream generator.
2026-04-14 21:41:25 +02:00
3 changed files with 225 additions and 6 deletions

View File

@@ -1631,6 +1631,29 @@ def _pipeline_debug_recording_thread_proc(
wav_writer.close()
async def _close_async_generators(
*generators: AsyncIterable[Any] | None,
) -> None:
"""Close async generators, suppressing non-cancellation errors.
If ``aclose()`` on one generator is cancelled, the others are still
attempted; the cancellation is re-raised once all generators have
been processed.
"""
cancelled_exc: asyncio.CancelledError | None = None
for gen in generators:
aclose = getattr(gen, "aclose", None)
if aclose is not None:
try:
await aclose()
except asyncio.CancelledError as exc:
cancelled_exc = exc
except Exception: # noqa: BLE001
pass
if cancelled_exc is not None:
raise cancelled_exc
@dataclass(kw_only=True)
class PipelineInput:
"""Input to a pipeline run."""
@@ -1680,12 +1703,16 @@ class PipelineInput:
)
current_stage: PipelineStage | None = self.run.start_stage
# Track async generators so they can be closed on early exit
# (validation error, no wake word, cancellation, etc.).
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
stt_input_stream: AsyncIterable[EnhancedAudioChunk] | None = None
try:
if validation_error is not None:
raise validation_error
stt_audio_buffer: list[EnhancedAudioChunk] = []
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
if self.stt_stream is not None:
if self.run.audio_settings.needs_processor:
@@ -1800,9 +1827,44 @@ class PipelineInput:
)
)
finally:
# Always end the run since it needs to shut down the debug recording
# thread, etc.
await self._cleanup(stt_input_stream, stt_processed_stream)
async def _cleanup(
    self,
    stt_input_stream: AsyncIterable[EnhancedAudioChunk] | None,
    stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None,
) -> None:
    """Release pipeline resources, even when cancelled part-way through.

    Two steps always run, in this order:

    1. Close the STT audio stream async generators so buffered audio
       chunks and the audio enhancer's VAD state are released promptly
       instead of waiting on garbage collection (noted as especially
       slow on Python 3.14+). ``stt_input_stream`` (the wrapper) is
       closed before ``stt_processed_stream`` (the upstream); when both
       refer to the same object it is only closed once.
    2. End the run, which needs to shut down the debug recording
       thread, etc.

    A CancelledError raised by either step (WebSocket unsubscribe,
    timeout) is held back until both steps have been attempted, so a
    cancelled pipeline cannot skip part of the cleanup chain. The first
    cancellation caught is re-raised at the end.
    """
    # Avoid a double-close when no separate wrapper stream exists.
    wrapper_stream = stt_input_stream
    if wrapper_stream is stt_processed_stream:
        wrapper_stream = None

    pending_cancel: asyncio.CancelledError | None = None

    try:
        await _close_async_generators(wrapper_stream, stt_processed_stream)
    except asyncio.CancelledError as err:
        pending_cancel = err

    try:
        # Always end the run since it needs to shut down the debug
        # recording thread, etc.
        await self.run.end()
    except asyncio.CancelledError as err:
        # Keep the earlier cancellation if there already is one.
        if pending_cancel is None:
            pending_cancel = err

    if pending_cancel is not None:
        raise pending_cancel
async def validate(self) -> None:
"""Validate pipeline input against start stage."""

View File

@@ -155,7 +155,10 @@ async def websocket_run(
if start_stage in (PipelineStage.WAKE_WORD, PipelineStage.STT):
# Audio pipeline that will receive audio as binary websocket messages
msg_input = msg["input"]
audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
# ~2.6s of 16kHz mono audio at 10ms chunks — enough to absorb
# brief stalls but bounded so a stalled consumer can't grow
# memory unboundedly.
audio_queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=256)
incoming_sample_rate = msg_input["sample_rate"]
wake_word_phrase: str | None = None
@@ -188,8 +191,11 @@ async def websocket_run(
_connection: websocket_api.ActiveConnection,
data: bytes,
) -> None:
# Forward to STT audio stream
audio_queue.put_nowait(data)
# Forward to STT audio stream.
# Drop frames if the pipeline can't keep up rather than
# growing the queue without bound.
with contextlib.suppress(asyncio.QueueFull):
audio_queue.put_nowait(data)
handler_id, unregister_handler = connection.async_register_binary_handler(
handle_binary
@@ -273,6 +279,20 @@ async def websocket_run(
# Unregister binary handler
unregister_handler()
# Send stop signal to unblock the stt_stream generator.
# Empty bytes is falsy and causes the ``while chunk :=``
# loop to exit cleanly. If the bounded queue is full,
# discard queued audio until there is room for the stop
# sentinel so the stream can always exit.
while True:
try:
audio_queue.put_nowait(b"")
except asyncio.QueueFull:
with contextlib.suppress(asyncio.QueueEmpty):
audio_queue.get_nowait()
else:
break
@callback
@websocket_api.require_admin

View File

@@ -1,5 +1,6 @@
"""Websocket tests for Voice Assistant integration."""
import asyncio
from collections.abc import AsyncGenerator, Generator
from pathlib import Path
from typing import Any
@@ -35,6 +36,7 @@ from homeassistant.components.assist_pipeline.pipeline import (
PipelineStorageCollection,
PipelineStore,
_async_local_fallback_intent_filter,
_close_async_generators,
async_create_default_pipeline,
async_get_pipeline,
async_get_pipelines,
@@ -2153,3 +2155,138 @@ async def test_acknowledge_other_agents(
text_to_speech.assert_not_called()
async_converse.assert_called_once()
get_all_targets_in_satellite_area.assert_not_called()
async def test_close_async_generators_closes_generators() -> None:
    """Verify the helper closes every generator it is given, in order."""
    finalized: list[str] = []

    async def tracked(label: str) -> AsyncGenerator[bytes]:
        """Yield once and record the label when finalized."""
        try:
            yield b""
        finally:
            finalized.append(label)

    streams = [tracked(label) for label in ("a", "b")]

    # Advance each generator to its first yield; a generator that was
    # never started has nothing to close.
    for stream in streams:
        await anext(stream)

    await _close_async_generators(*streams)

    assert finalized == ["a", "b"]
async def test_close_async_generators_handles_none() -> None:
    """Verify None and objects lacking ``aclose`` are skipped silently."""
    not_closable: tuple[Any, ...] = (None, "not a generator", None)

    # Returning without raising is the whole assertion here.
    await _close_async_generators(*not_closable)
async def test_close_async_generators_suppresses_errors() -> None:
    """Verify an error raised while a generator closes is not propagated."""

    async def failing_on_close() -> AsyncGenerator[bytes]:
        """Yield once, then blow up during finalization."""
        try:
            yield b""
        finally:
            raise RuntimeError("boom")

    stream = failing_on_close()
    # Enter the try block so that closing triggers the finally clause.
    await anext(stream)

    # The RuntimeError from the generator's finally must be swallowed.
    await _close_async_generators(stream)
async def test_close_async_generators_closes_all_on_cancellation() -> None:
    """Verify a cancellation during one close does not skip the others.

    Regression guard for the leak scenario: when ``aclose()`` on one
    generator raises CancelledError, every remaining generator must
    still be closed so no audio buffers or VAD state are orphaned.
    """
    finalized: list[str] = []

    async def tracked(label: str, *, cancel: bool) -> AsyncGenerator[bytes]:
        """Yield once; on close record the label and optionally cancel."""
        try:
            yield b""
        finally:
            finalized.append(label)
            if cancel:
                raise asyncio.CancelledError

    cancelling = tracked("cancel", cancel=True)
    well_behaved = tracked("normal", cancel=False)

    # Both must be suspended inside their try block before closing.
    await anext(cancelling)
    await anext(well_behaved)

    with pytest.raises(asyncio.CancelledError):
        await _close_async_generators(cancelling, well_behaved)

    # The second generator was still attempted after the first cancelled.
    assert finalized == ["cancel", "normal"]
async def test_pipeline_execute_closes_stt_generators(
    hass: HomeAssistant,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
) -> None:
    """Verify PipelineInput.execute finalizes the STT audio generator.

    Regression coverage for a leak where an early exit of the pipeline
    (here: no wake word detected) left the upstream audio generator
    un-closed, keeping audio buffers and the audio enhancer's VAD state
    alive.
    """
    closed = asyncio.Event()

    async def audio_data() -> AsyncGenerator[bytes]:
        """Yield one chunk without a wake word, then the empty stop chunk."""
        try:
            yield make_10ms_chunk(b"silence!")
            yield b""
        finally:
            # Runs when the generator is finalized.
            closed.set()

    # Resolve the preferred pipeline from the store.
    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    events: list[assist_pipeline.PipelineEvent] = []

    # Start at the wake word stage with VAD disabled.
    run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=events.append,
        tts_audio_output=None,
        audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
    )
    metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        session=mock_chat_session,
        device_id=None,
        stt_metadata=metadata,
        stt_stream=audio_data(),
        run=run,
    )

    await pipeline_input.validate()
    await pipeline_input.execute()

    # Pipeline aborted (no wake word) — generator must have been closed.
    assert closed.is_set()