Fix Wyoming satellite memory leak on disconnect (#168152)

This commit is contained in:
Marcel van der Veldt
2026-04-14 17:37:36 +02:00
committed by GitHub
parent 939412717f
commit 073d22d046
2 changed files with 356 additions and 18 deletions

View File

@@ -181,7 +181,19 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Set state based on pipeline stage."""
assert self._client is not None
if event.type == assist_pipeline.PipelineEventType.RUN_END:
# Pipeline run is complete — always update bookkeeping state
# even after a disconnect so follow-up reconnects don't retain
# stale _is_pipeline_running / _pipeline_ended_event state.
self._is_pipeline_running = False
self._pipeline_ended_event.set()
self.device.set_is_active(False)
self._tts_stream_token = None
self._is_tts_streaming = False
if self._client is None:
# Satellite disconnected, don't try to write to the client
return
if event.type == assist_pipeline.PipelineEventType.RUN_START:
if event.data and (tts_output := event.data["tts_output"]):
@@ -190,13 +202,6 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
# can start streaming TTS before the TTS_END event.
self._tts_stream_token = tts_output["token"]
self._is_tts_streaming = False
elif event.type == assist_pipeline.PipelineEventType.RUN_END:
# Pipeline run is complete
self._is_pipeline_running = False
self._pipeline_ended_event.set()
self.device.set_is_active(False)
self._tts_stream_token = None
self._is_tts_streaming = False
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
self.config_entry.async_create_background_task(
self.hass,
@@ -321,7 +326,8 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
Should block until the announcement is done playing.
"""
assert self._client is not None
if self._client is None:
raise ConnectionError("Satellite is not connected")
if self._ffmpeg_manager is None:
self._ffmpeg_manager = ffmpeg.get_ffmpeg_manager(self.hass)
@@ -441,6 +447,11 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
# Stop any existing pipeline
self._audio_queue.put_nowait(None)
# Cancel any pipeline still running so its background
# tasks and audio buffers can be released instead of
# being orphaned across the reconnect.
await self._cancel_running_pipeline()
# Ensure sensor is off (before restart)
self.device.set_is_active(False)
@@ -449,6 +460,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
finally:
unregister_timer_handler()
# Cancel any pipeline still running on final teardown.
await self._cancel_running_pipeline()
# Ensure sensor is off (before stop)
self.device.set_is_active(False)
@@ -699,10 +713,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
async def _send_delayed_ping(self) -> None:
"""Send ping to satellite after a delay."""
assert self._client is not None
try:
await asyncio.sleep(_PING_SEND_DELAY)
if self._client is None:
return
await self._client.write_event(Ping().event())
except ConnectionError:
pass # handled with timeout
@@ -728,7 +742,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
"""Stream TTS WAV audio to satellite in chunks."""
assert self._client is not None
client = self._client
if client is None:
# Satellite disconnected, cannot stream
return
if tts_result.extension != "wav":
raise ValueError(
@@ -760,7 +777,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
sample_rate, sample_width, sample_channels, data_chunk = (
audio_info
)
await self._client.write_event(
await client.write_event(
AudioStart(
rate=sample_rate,
width=sample_width,
@@ -794,12 +811,12 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
timestamp=timestamp,
)
await self._client.write_event(audio_chunk.event())
await client.write_event(audio_chunk.event())
timestamp += audio_chunk.milliseconds
total_seconds += audio_chunk.seconds
data_chunk_idx += _AUDIO_CHUNK_BYTES
await self._client.write_event(AudioStop(timestamp=timestamp).event())
await client.write_event(AudioStop(timestamp=timestamp).event())
_LOGGER.debug("TTS streaming complete")
finally:
send_duration = time.monotonic() - start_time
@@ -840,7 +857,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
self, event_type: intent.TimerEventType, timer: intent.TimerInfo
) -> None:
"""Forward timer events to satellite."""
assert self._client is not None
if self._client is None:
# Satellite disconnected, drop timer event
return
_LOGGER.debug("Timer event: type=%s, info=%s", event_type, timer)
event: Event | None = None

View File

@@ -7,9 +7,10 @@ from collections.abc import Callable
import io
import tempfile
from typing import Any
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import wave
import pytest
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error
@@ -24,7 +25,7 @@ from wyoming.tts import Synthesize
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, assist_satellite, wyoming
from homeassistant.components import assist_pipeline, assist_satellite, intent, wyoming
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.const import STATE_ON
@@ -655,6 +656,324 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
assert not device.is_active
async def test_satellite_disconnect_cancels_running_pipeline(
hass: HomeAssistant,
) -> None:
"""Test that a satellite disconnect cancels the in-flight pipeline task.
Regression test for a memory leak introduced in 2026.4.0 where a Wyoming
client disconnection left the pipeline task running in the background, so
every lingering pipeline event tried to write to a now-``None`` client and
accumulated background tasks until the process was OOM-killed.
"""
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
] # no audio chunks after RunPipeline, peer goes away
pipeline_started = asyncio.Event()
pipeline_cancelled = asyncio.Event()
on_restart_event = asyncio.Event()
on_stopped_event = asyncio.Event()
async def _long_running_pipeline(*args: Any, **kwargs: Any) -> None:
pipeline_started.set()
try:
# Keep the pipeline alive until it gets cancelled by the satellite.
await asyncio.Event().wait()
except asyncio.CancelledError:
pipeline_cancelled.set()
raise
async def on_restart(self):
self.stop_satellite()
on_restart_event.set()
async def on_stopped(self):
on_stopped_event.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
MockAsyncTcpClient(events),
),
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_long_running_pipeline,
),
patch(
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
on_restart,
),
patch(
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
on_stopped,
),
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
# Pipeline starts, then the peer disconnects, satellite should
# cancel the pipeline before restarting the connection.
await pipeline_started.wait()
await pipeline_cancelled.wait()
await on_restart_event.wait()
await on_stopped_event.wait()
async def test_on_pipeline_event_ignores_disconnected_client(
hass: HomeAssistant,
) -> None:
"""Test that ``on_pipeline_event`` is a no-op after the client disconnected.
Previously this path hit ``assert self._client is not None``, which raised
``AssertionError`` once per event while the pipeline kept running after a
disconnect, contributing to the memory leak in 2026.4.0.
"""
events: list[Event] = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
pipeline_event = asyncio.Event()
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
pipeline_event.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await pipeline_event.wait()
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
# event_callback is the base class's bound _internal_on_pipeline_event,
# so we can reach the satellite entity from there.
satellite: WyomingAssistSatellite = event_callback.__self__
# Simulate the disconnect race: the pipeline is still firing events
# but the TCP client has already been torn down.
satellite._client = None
# Must not raise, must not spawn a background write task.
for event_type in (
assist_pipeline.PipelineEventType.WAKE_WORD_START,
assist_pipeline.PipelineEventType.STT_START,
assist_pipeline.PipelineEventType.STT_END,
assist_pipeline.PipelineEventType.TTS_START,
assist_pipeline.PipelineEventType.ERROR,
):
event_callback(
assist_pipeline.PipelineEvent(
event_type,
{
"metadata": {"language": "en"},
"stt_output": {"text": "ignored"},
"tts_input": "ignored",
"code": "err",
"message": "ignored",
"timestamp": 0,
},
)
)
# RUN_END must still update bookkeeping even with no client.
satellite._is_pipeline_running = True
satellite._pipeline_ended_event.clear()
event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END, {})
)
assert not satellite._is_pipeline_running
assert satellite._pipeline_ended_event.is_set()
# Flush any stray background tasks before asserting on side effects.
await hass.async_block_till_done()
# If the guard did not hold, the mock client would have observed
# ``Detect``, ``Transcribe``, ``Transcript``, ``Synthesize`` and
# ``Error`` events.
assert not mock_client.detect_event.is_set()
assert not mock_client.transcribe_event.is_set()
assert not mock_client.transcript_event.is_set()
assert not mock_client.synthesize_event.is_set()
assert not mock_client.error_event.is_set()
async def test_announce_raises_when_client_disconnected(
hass: HomeAssistant,
) -> None:
"""Test that async_announce raises ConnectionError when client is None."""
events: list[Event] = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
pipeline_event = asyncio.Event()
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
pipeline_event.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await pipeline_event.wait()
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
satellite: WyomingAssistSatellite = event_callback.__self__
satellite._client = None
with pytest.raises(ConnectionError, match="not connected"):
await satellite.async_announce(
assist_satellite.AssistSatelliteAnnouncement(
message="test",
media_id="test",
original_media_id="test",
tts_token=None,
media_id_source="tts",
)
)
async def test_stream_tts_noop_when_client_disconnected(
hass: HomeAssistant,
) -> None:
"""Test that _stream_tts returns immediately when client is None."""
events: list[Event] = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
pipeline_event = asyncio.Event()
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
pipeline_event.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await pipeline_event.wait()
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
satellite: WyomingAssistSatellite = event_callback.__self__
satellite._client = None
# Should return immediately without touching the stream object
await satellite._stream_tts(MagicMock())
async def test_handle_timer_noop_when_client_disconnected(
hass: HomeAssistant,
) -> None:
"""Test that _handle_timer returns immediately when client is None."""
events: list[Event] = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
pipeline_event = asyncio.Event()
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
pipeline_event.set()
with (
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client,
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
wraps=_async_pipeline_from_audio_stream,
) as mock_run_pipeline,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await pipeline_event.wait()
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
satellite: WyomingAssistSatellite = event_callback.__self__
satellite._client = None
# Should not raise
satellite._handle_timer(
intent.TimerEventType.STARTED,
intent.TimerInfo(
id="test-timer",
name="test",
seconds=30,
device_id=None,
start_hours=0,
start_minutes=0,
start_seconds=30,
created_at=0,
updated_at=0,
language="en",
),
)
async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
"""Test satellite error occurring during pipeline run."""
events = [