Support streaming TTS in wyoming (#147392)

* Support streaming TTS in wyoming

* Add test

* Refactor to avoid repeated task creation

* Manually manage client lifecycle
This commit is contained in:
Michael Hansen
2025-06-24 12:04:40 -05:00
committed by GitHub
parent 3dc8676b99
commit cefc8822b6
5 changed files with 242 additions and 7 deletions

View File

@@ -1,13 +1,21 @@
"""Support for Wyoming text-to-speech services."""
from collections import defaultdict
from collections.abc import AsyncGenerator
import io
import logging
import wave
from wyoming.audio import AudioChunk, AudioStop
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.tts import Synthesize, SynthesizeVoice
from wyoming.tts import (
Synthesize,
SynthesizeChunk,
SynthesizeStart,
SynthesizeStop,
SynthesizeStopped,
SynthesizeVoice,
)
from homeassistant.components import tts
from homeassistant.config_entries import ConfigEntry
@@ -45,6 +53,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
service: WyomingService,
) -> None:
"""Set up provider."""
self.config_entry = config_entry
self.service = service
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
@@ -150,3 +159,98 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
return (None, None)
return ("wav", data)
def async_supports_streaming_input(self) -> bool:
"""Return if the TTS engine supports streaming input."""
return self._tts_service.supports_synthesize_streaming
async def async_stream_tts_audio(
self, request: tts.TTSAudioRequest
) -> tts.TTSAudioResponse:
"""Generate speech from an incoming message."""
voice_name: str | None = request.options.get(tts.ATTR_VOICE)
voice_speaker: str | None = request.options.get(ATTR_SPEAKER)
voice: SynthesizeVoice | None = None
if voice_name is not None:
voice = SynthesizeVoice(name=voice_name, speaker=voice_speaker)
client = AsyncTcpClient(self.service.host, self.service.port)
await client.connect()
# Stream text chunks to client
self.config_entry.async_create_background_task(
self.hass,
self._write_tts_message(request.message_gen, client, voice),
"wyoming tts write",
)
async def data_gen():
# Stream audio bytes from client
try:
async for data_chunk in self._read_tts_audio(client):
yield data_chunk
finally:
await client.disconnect()
return tts.TTSAudioResponse("wav", data_gen())
async def _write_tts_message(
self,
message_gen: AsyncGenerator[str],
client: AsyncTcpClient,
voice: SynthesizeVoice | None,
) -> None:
"""Write text chunks to the client."""
try:
# Start stream
await client.write_event(SynthesizeStart(voice=voice).event())
# Accumulate entire message for synthesize event.
message = ""
async for message_chunk in message_gen:
message += message_chunk
await client.write_event(SynthesizeChunk(text=message_chunk).event())
# Send entire message for backwards compatibility
await client.write_event(Synthesize(text=message, voice=voice).event())
# End stream
await client.write_event(SynthesizeStop().event())
except (OSError, WyomingError):
# Disconnected
_LOGGER.warning("Unexpected disconnection from TTS client")
async def _read_tts_audio(self, client: AsyncTcpClient) -> AsyncGenerator[bytes]:
"""Read audio events from the client and yield WAV audio chunks.
The WAV header is sent first with a frame count of 0 to indicate that
we're streaming and don't know the number of frames ahead of time.
"""
wav_header_sent = False
try:
while event := await client.read_event():
if wav_header_sent and AudioChunk.is_type(event.type):
# PCM audio
yield AudioChunk.from_event(event).audio
elif (not wav_header_sent) and AudioStart.is_type(event.type):
# WAV header with nframes = 0 for streaming
audio_start = AudioStart.from_event(event)
with io.BytesIO() as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
with wav_file:
wav_file.setframerate(audio_start.rate)
wav_file.setsampwidth(audio_start.width)
wav_file.setnchannels(audio_start.channels)
wav_io.seek(0)
yield wav_io.getvalue()
wav_header_sent = True
elif SynthesizeStopped.is_type(event.type):
# All TTS audio has been received
break
except (OSError, WyomingError):
# Disconnected
_LOGGER.warning("Unexpected disconnection from TTS client")