mirror of
https://github.com/home-assistant/core.git
synced 2026-04-19 07:59:14 +02:00
Fix Wyoming satellite memory leak on disconnect (#168152)
This commit is contained in:
committed by
GitHub
parent
939412717f
commit
073d22d046
@@ -181,7 +181,19 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
|
||||
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||
"""Set state based on pipeline stage."""
|
||||
assert self._client is not None
|
||||
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||
# Pipeline run is complete — always update bookkeeping state
|
||||
# even after a disconnect so follow-up reconnects don't retain
|
||||
# stale _is_pipeline_running / _pipeline_ended_event state.
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event.set()
|
||||
self.device.set_is_active(False)
|
||||
self._tts_stream_token = None
|
||||
self._is_tts_streaming = False
|
||||
|
||||
if self._client is None:
|
||||
# Satellite disconnected, don't try to write to the client
|
||||
return
|
||||
|
||||
if event.type == assist_pipeline.PipelineEventType.RUN_START:
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
@@ -190,13 +202,6 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
# can start streaming TTS before the TTS_END event.
|
||||
self._tts_stream_token = tts_output["token"]
|
||||
self._is_tts_streaming = False
|
||||
elif event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||
# Pipeline run is complete
|
||||
self._is_pipeline_running = False
|
||||
self._pipeline_ended_event.set()
|
||||
self.device.set_is_active(False)
|
||||
self._tts_stream_token = None
|
||||
self._is_tts_streaming = False
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
@@ -321,7 +326,8 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
|
||||
Should block until the announcement is done playing.
|
||||
"""
|
||||
assert self._client is not None
|
||||
if self._client is None:
|
||||
raise ConnectionError("Satellite is not connected")
|
||||
|
||||
if self._ffmpeg_manager is None:
|
||||
self._ffmpeg_manager = ffmpeg.get_ffmpeg_manager(self.hass)
|
||||
@@ -441,6 +447,11 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
# Stop any existing pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
# Cancel any pipeline still running so its background
|
||||
# tasks and audio buffers can be released instead of
|
||||
# being orphaned across the reconnect.
|
||||
await self._cancel_running_pipeline()
|
||||
|
||||
# Ensure sensor is off (before restart)
|
||||
self.device.set_is_active(False)
|
||||
|
||||
@@ -449,6 +460,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
finally:
|
||||
unregister_timer_handler()
|
||||
|
||||
# Cancel any pipeline still running on final teardown.
|
||||
await self._cancel_running_pipeline()
|
||||
|
||||
# Ensure sensor is off (before stop)
|
||||
self.device.set_is_active(False)
|
||||
|
||||
@@ -699,10 +713,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
|
||||
async def _send_delayed_ping(self) -> None:
|
||||
"""Send ping to satellite after a delay."""
|
||||
assert self._client is not None
|
||||
|
||||
try:
|
||||
await asyncio.sleep(_PING_SEND_DELAY)
|
||||
if self._client is None:
|
||||
return
|
||||
await self._client.write_event(Ping().event())
|
||||
except ConnectionError:
|
||||
pass # handled with timeout
|
||||
@@ -728,7 +742,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
|
||||
async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
|
||||
"""Stream TTS WAV audio to satellite in chunks."""
|
||||
assert self._client is not None
|
||||
client = self._client
|
||||
if client is None:
|
||||
# Satellite disconnected, cannot stream
|
||||
return
|
||||
|
||||
if tts_result.extension != "wav":
|
||||
raise ValueError(
|
||||
@@ -760,7 +777,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
sample_rate, sample_width, sample_channels, data_chunk = (
|
||||
audio_info
|
||||
)
|
||||
await self._client.write_event(
|
||||
await client.write_event(
|
||||
AudioStart(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
@@ -794,12 +811,12 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
await self._client.write_event(audio_chunk.event())
|
||||
await client.write_event(audio_chunk.event())
|
||||
timestamp += audio_chunk.milliseconds
|
||||
total_seconds += audio_chunk.seconds
|
||||
data_chunk_idx += _AUDIO_CHUNK_BYTES
|
||||
|
||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||
await client.write_event(AudioStop(timestamp=timestamp).event())
|
||||
_LOGGER.debug("TTS streaming complete")
|
||||
finally:
|
||||
send_duration = time.monotonic() - start_time
|
||||
@@ -840,7 +857,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||
self, event_type: intent.TimerEventType, timer: intent.TimerInfo
|
||||
) -> None:
|
||||
"""Forward timer events to satellite."""
|
||||
assert self._client is not None
|
||||
if self._client is None:
|
||||
# Satellite disconnected, drop timer event
|
||||
return
|
||||
|
||||
_LOGGER.debug("Timer event: type=%s, info=%s", event_type, timer)
|
||||
event: Event | None = None
|
||||
|
||||
@@ -7,9 +7,10 @@ from collections.abc import Callable
|
||||
import io
|
||||
import tempfile
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
from wyoming.asr import Transcribe, Transcript
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.error import Error
|
||||
@@ -24,7 +25,7 @@ from wyoming.tts import Synthesize
|
||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, wyoming
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, intent, wyoming
|
||||
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.const import STATE_ON
|
||||
@@ -655,6 +656,324 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
|
||||
assert not device.is_active
|
||||
|
||||
|
||||
async def test_satellite_disconnect_cancels_running_pipeline(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that a satellite disconnect cancels the in-flight pipeline task.
|
||||
|
||||
Regression test for a memory leak introduced in 2026.4.0 where a Wyoming
|
||||
client disconnection left the pipeline task running in the background, so
|
||||
every lingering pipeline event tried to write to a now-``None`` client and
|
||||
accumulated background tasks until the process was OOM-killed.
|
||||
"""
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
] # no audio chunks after RunPipeline, peer goes away
|
||||
|
||||
pipeline_started = asyncio.Event()
|
||||
pipeline_cancelled = asyncio.Event()
|
||||
on_restart_event = asyncio.Event()
|
||||
on_stopped_event = asyncio.Event()
|
||||
|
||||
async def _long_running_pipeline(*args: Any, **kwargs: Any) -> None:
|
||||
pipeline_started.set()
|
||||
try:
|
||||
# Keep the pipeline alive until it gets cancelled by the satellite.
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError:
|
||||
pipeline_cancelled.set()
|
||||
raise
|
||||
|
||||
async def on_restart(self):
|
||||
self.stop_satellite()
|
||||
on_restart_event.set()
|
||||
|
||||
async def on_stopped(self):
|
||||
on_stopped_event.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient(events),
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_long_running_pipeline,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
|
||||
on_restart,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
|
||||
on_stopped,
|
||||
),
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
# Pipeline starts, then the peer disconnects, satellite should
|
||||
# cancel the pipeline before restarting the connection.
|
||||
await pipeline_started.wait()
|
||||
await pipeline_cancelled.wait()
|
||||
await on_restart_event.wait()
|
||||
await on_stopped_event.wait()
|
||||
|
||||
|
||||
async def test_on_pipeline_event_ignores_disconnected_client(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that ``on_pipeline_event`` is a no-op after the client disconnected.
|
||||
|
||||
Previously this path hit ``assert self._client is not None``, which raised
|
||||
``AssertionError`` once per event while the pipeline kept running after a
|
||||
disconnect, contributing to the memory leak in 2026.4.0.
|
||||
"""
|
||||
events: list[Event] = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_event = asyncio.Event()
|
||||
|
||||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||
pipeline_event.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await pipeline_event.wait()
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
|
||||
# event_callback is the base class's bound _internal_on_pipeline_event,
|
||||
# so we can reach the satellite entity from there.
|
||||
satellite: WyomingAssistSatellite = event_callback.__self__
|
||||
|
||||
# Simulate the disconnect race: the pipeline is still firing events
|
||||
# but the TCP client has already been torn down.
|
||||
satellite._client = None
|
||||
|
||||
# Must not raise, must not spawn a background write task.
|
||||
for event_type in (
|
||||
assist_pipeline.PipelineEventType.WAKE_WORD_START,
|
||||
assist_pipeline.PipelineEventType.STT_START,
|
||||
assist_pipeline.PipelineEventType.STT_END,
|
||||
assist_pipeline.PipelineEventType.TTS_START,
|
||||
assist_pipeline.PipelineEventType.ERROR,
|
||||
):
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
event_type,
|
||||
{
|
||||
"metadata": {"language": "en"},
|
||||
"stt_output": {"text": "ignored"},
|
||||
"tts_input": "ignored",
|
||||
"code": "err",
|
||||
"message": "ignored",
|
||||
"timestamp": 0,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# RUN_END must still update bookkeeping even with no client.
|
||||
satellite._is_pipeline_running = True
|
||||
satellite._pipeline_ended_event.clear()
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END, {})
|
||||
)
|
||||
assert not satellite._is_pipeline_running
|
||||
assert satellite._pipeline_ended_event.is_set()
|
||||
|
||||
# Flush any stray background tasks before asserting on side effects.
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# If the guard did not hold, the mock client would have observed
|
||||
# ``Detect``, ``Transcribe``, ``Transcript``, ``Synthesize`` and
|
||||
# ``Error`` events.
|
||||
assert not mock_client.detect_event.is_set()
|
||||
assert not mock_client.transcribe_event.is_set()
|
||||
assert not mock_client.transcript_event.is_set()
|
||||
assert not mock_client.synthesize_event.is_set()
|
||||
assert not mock_client.error_event.is_set()
|
||||
|
||||
|
||||
async def test_announce_raises_when_client_disconnected(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that async_announce raises ConnectionError when client is None."""
|
||||
events: list[Event] = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_event = asyncio.Event()
|
||||
|
||||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||
pipeline_event.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await pipeline_event.wait()
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
|
||||
satellite: WyomingAssistSatellite = event_callback.__self__
|
||||
satellite._client = None
|
||||
|
||||
with pytest.raises(ConnectionError, match="not connected"):
|
||||
await satellite.async_announce(
|
||||
assist_satellite.AssistSatelliteAnnouncement(
|
||||
message="test",
|
||||
media_id="test",
|
||||
original_media_id="test",
|
||||
tts_token=None,
|
||||
media_id_source="tts",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def test_stream_tts_noop_when_client_disconnected(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that _stream_tts returns immediately when client is None."""
|
||||
events: list[Event] = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_event = asyncio.Event()
|
||||
|
||||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||
pipeline_event.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await pipeline_event.wait()
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
|
||||
satellite: WyomingAssistSatellite = event_callback.__self__
|
||||
satellite._client = None
|
||||
|
||||
# Should return immediately without touching the stream object
|
||||
await satellite._stream_tts(MagicMock())
|
||||
|
||||
|
||||
async def test_handle_timer_noop_when_client_disconnected(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that _handle_timer returns immediately when client is None."""
|
||||
events: list[Event] = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
pipeline_event = asyncio.Event()
|
||||
|
||||
def _async_pipeline_from_audio_stream(*args: Any, **kwargs: Any) -> None:
|
||||
pipeline_event.set()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
wraps=_async_pipeline_from_audio_stream,
|
||||
) as mock_run_pipeline,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await pipeline_event.wait()
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
|
||||
satellite: WyomingAssistSatellite = event_callback.__self__
|
||||
satellite._client = None
|
||||
|
||||
# Should not raise
|
||||
satellite._handle_timer(
|
||||
intent.TimerEventType.STARTED,
|
||||
intent.TimerInfo(
|
||||
id="test-timer",
|
||||
name="test",
|
||||
seconds=30,
|
||||
device_id=None,
|
||||
start_hours=0,
|
||||
start_minutes=0,
|
||||
start_seconds=30,
|
||||
created_at=0,
|
||||
updated_at=0,
|
||||
language="en",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
|
||||
"""Test satellite error occurring during pipeline run."""
|
||||
events = [
|
||||
|
||||
Reference in New Issue
Block a user