Support streaming TTS in wyoming (#147392)

* Support streaming TTS in wyoming

* Add test

* Refactor to avoid repeated task creation

* Manually manage client lifecycle
This commit is contained in:
Michael Hansen
2025-06-24 12:04:40 -05:00
committed by GitHub
parent 3dc8676b99
commit cefc8822b6
5 changed files with 242 additions and 7 deletions

View File

@ -69,6 +69,29 @@ TTS_INFO = Info(
)
]
)
# Service description for a TTS program that advertises streaming synthesis
# (supports_synthesize_streaming=True). It exposes one installed en-US voice
# with a single speaker.
TTS_STREAMING_INFO = Info(
    tts=[
        TtsProgram(
            name="Test Streaming TTS",
            description="Test Streaming TTS",
            attribution=TEST_ATTR,
            installed=True,
            version=None,
            supports_synthesize_streaming=True,
            voices=[
                TtsVoice(
                    name="Test Voice",
                    description="Test Voice",
                    attribution=TEST_ATTR,
                    installed=True,
                    version=None,
                    languages=["en-US"],
                    speakers=[TtsVoiceSpeaker(name="Test Speaker")],
                )
            ],
        )
    ]
)
WAKE_WORD_INFO = Info(
wake=[
WakeProgram(
@ -155,9 +178,15 @@ class MockAsyncTcpClient:
self.port: int | None = None
self.written: list[Event] = []
self.responses = responses
self.is_connected: bool | None = None
async def connect(self) -> None:
    """Mark the mock client as connected."""
    # Tests inspect this flag to check the client lifecycle.
    self.is_connected = True
async def disconnect(self) -> None:
    """Mark the mock client as disconnected."""
    # Tests assert on this flag to confirm the client was closed.
    self.is_connected = False
async def write_event(self, event: Event):
"""Send."""

View File

@ -19,6 +19,7 @@ from . import (
SATELLITE_INFO,
STT_INFO,
TTS_INFO,
TTS_STREAMING_INFO,
WAKE_WORD_INFO,
)
@ -148,6 +149,20 @@ async def init_wyoming_tts(
return tts_config_entry
@pytest.fixture
async def init_wyoming_streaming_tts(
    hass: HomeAssistant, tts_config_entry: ConfigEntry
) -> ConfigEntry:
    """Set up the Wyoming TTS config entry with streaming TTS info."""
    # While the entry is being set up, load_wyoming_info is patched to return
    # the description that advertises streaming synthesis.
    info_patch = patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=TTS_STREAMING_INFO,
    )
    with info_patch:
        await hass.config_entries.async_setup(tts_config_entry.entry_id)

    return tts_config_entry
@pytest.fixture
async def init_wyoming_wake_word(
hass: HomeAssistant, wake_word_config_entry: ConfigEntry

View File

@ -32,6 +32,43 @@
}),
])
# ---
# name: test_get_tts_audio_streaming
list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello ',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({
'data': dict({
'text': 'Word.',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({
'data': dict({
'text': 'Hello Word.',
}),
'payload': None,
'type': 'synthesize',
}),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
])
# ---
# name: test_voice_speaker
list([
dict({

View File

@ -8,7 +8,8 @@ import wave
import pytest
from syrupy.assertion import SnapshotAssertion
from wyoming.audio import AudioChunk, AudioStop
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.tts import SynthesizeStopped
from homeassistant.components import tts, wyoming
from homeassistant.core import HomeAssistant
@ -43,11 +44,11 @@ async def test_get_tts_audio(
hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion
) -> None:
"""Test get audio."""
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
assert entity is not None
assert not entity.async_supports_streaming_input()
audio = bytes(100)
audio_events = [
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
AudioStop().event(),
]
# Verify audio
audio_events = [
@ -215,3 +216,52 @@ async def test_voice_speaker(
),
)
assert mock_client.written == snapshot
async def test_get_tts_audio_streaming(
    hass: HomeAssistant, init_wyoming_streaming_tts, snapshot: SnapshotAssertion
) -> None:
    """Test getting audio when the message is supplied as a stream."""
    tts_entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_streaming_tts")
    assert tts_entity is not None
    assert tts_entity.async_supports_streaming_input()

    pcm = bytes(100)

    # Responses handed to the mock client: one audio chunk framed by
    # start/stop events, then the synthesize-stopped event.
    response_events = [
        AudioStart(rate=16000, width=2, channels=1).event(),
        AudioChunk(audio=pcm, rate=16000, width=2, channels=1).event(),
        AudioStop().event(),
        SynthesizeStopped().event(),
    ]

    async def message_gen():
        # The message arrives in two pieces
        for piece in ("Hello ", "Word."):
            yield piece

    with patch(
        "homeassistant.components.wyoming.tts.AsyncTcpClient",
        MockAsyncTcpClient(response_events),
    ) as mock_client:
        result_stream = tts.async_create_stream(
            hass,
            "tts.test_streaming_tts",
            "en-US",
            options={tts.ATTR_PREFERRED_FORMAT: "wav"},
        )
        result_stream.async_set_message_stream(message_gen())

        received = []
        async for chunk in result_stream.async_stream_result():
            received.append(chunk)
        wav_bytes = b"".join(received)

        # Ensure client was disconnected properly
        assert mock_client.is_connected is False

    assert wav_bytes is not None
    with io.BytesIO(wav_bytes) as wav_io, wave.open(wav_io, "rb") as wav_file:
        assert wav_file.getframerate() == 16000
        assert wav_file.getsampwidth() == 2
        assert wav_file.getnchannels() == 1
        assert wav_file.getnframes() == 0  # streaming

        # WAV header is 44 bytes
        assert wav_bytes[44:] == pcm

    assert mock_client.written == snapshot