mirror of
https://github.com/home-assistant/core.git
synced 2025-08-02 20:25:07 +02:00
Support streaming TTS in wyoming (#147392)
* Support streaming TTS in wyoming * Add test * Refactor to avoid repeated task creation * Manually manage client lifecycle
This commit is contained in:
@@ -1,13 +1,21 @@
|
||||
"""Support for Wyoming text-to-speech services."""
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncGenerator
|
||||
import io
|
||||
import logging
|
||||
import wave
|
||||
|
||||
from wyoming.audio import AudioChunk, AudioStop
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.tts import Synthesize, SynthesizeVoice
|
||||
from wyoming.tts import (
|
||||
Synthesize,
|
||||
SynthesizeChunk,
|
||||
SynthesizeStart,
|
||||
SynthesizeStop,
|
||||
SynthesizeStopped,
|
||||
SynthesizeVoice,
|
||||
)
|
||||
|
||||
from homeassistant.components import tts
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@@ -45,6 +53,7 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||
service: WyomingService,
|
||||
) -> None:
|
||||
"""Set up provider."""
|
||||
self.config_entry = config_entry
|
||||
self.service = service
|
||||
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
|
||||
|
||||
@@ -150,3 +159,98 @@ class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||
return (None, None)
|
||||
|
||||
return ("wav", data)
|
||||
|
||||
def async_supports_streaming_input(self) -> bool:
|
||||
"""Return if the TTS engine supports streaming input."""
|
||||
return self._tts_service.supports_synthesize_streaming
|
||||
|
||||
async def async_stream_tts_audio(
|
||||
self, request: tts.TTSAudioRequest
|
||||
) -> tts.TTSAudioResponse:
|
||||
"""Generate speech from an incoming message."""
|
||||
voice_name: str | None = request.options.get(tts.ATTR_VOICE)
|
||||
voice_speaker: str | None = request.options.get(ATTR_SPEAKER)
|
||||
voice: SynthesizeVoice | None = None
|
||||
if voice_name is not None:
|
||||
voice = SynthesizeVoice(name=voice_name, speaker=voice_speaker)
|
||||
|
||||
client = AsyncTcpClient(self.service.host, self.service.port)
|
||||
await client.connect()
|
||||
|
||||
# Stream text chunks to client
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._write_tts_message(request.message_gen, client, voice),
|
||||
"wyoming tts write",
|
||||
)
|
||||
|
||||
async def data_gen():
|
||||
# Stream audio bytes from client
|
||||
try:
|
||||
async for data_chunk in self._read_tts_audio(client):
|
||||
yield data_chunk
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
return tts.TTSAudioResponse("wav", data_gen())
|
||||
|
||||
async def _write_tts_message(
|
||||
self,
|
||||
message_gen: AsyncGenerator[str],
|
||||
client: AsyncTcpClient,
|
||||
voice: SynthesizeVoice | None,
|
||||
) -> None:
|
||||
"""Write text chunks to the client."""
|
||||
try:
|
||||
# Start stream
|
||||
await client.write_event(SynthesizeStart(voice=voice).event())
|
||||
|
||||
# Accumulate entire message for synthesize event.
|
||||
message = ""
|
||||
async for message_chunk in message_gen:
|
||||
message += message_chunk
|
||||
|
||||
await client.write_event(SynthesizeChunk(text=message_chunk).event())
|
||||
|
||||
# Send entire message for backwards compatibility
|
||||
await client.write_event(Synthesize(text=message, voice=voice).event())
|
||||
|
||||
# End stream
|
||||
await client.write_event(SynthesizeStop().event())
|
||||
except (OSError, WyomingError):
|
||||
# Disconnected
|
||||
_LOGGER.warning("Unexpected disconnection from TTS client")
|
||||
|
||||
async def _read_tts_audio(self, client: AsyncTcpClient) -> AsyncGenerator[bytes]:
|
||||
"""Read audio events from the client and yield WAV audio chunks.
|
||||
|
||||
The WAV header is sent first with a frame count of 0 to indicate that
|
||||
we're streaming and don't know the number of frames ahead of time.
|
||||
"""
|
||||
wav_header_sent = False
|
||||
|
||||
try:
|
||||
while event := await client.read_event():
|
||||
if wav_header_sent and AudioChunk.is_type(event.type):
|
||||
# PCM audio
|
||||
yield AudioChunk.from_event(event).audio
|
||||
elif (not wav_header_sent) and AudioStart.is_type(event.type):
|
||||
# WAV header with nframes = 0 for streaming
|
||||
audio_start = AudioStart.from_event(event)
|
||||
with io.BytesIO() as wav_io:
|
||||
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
|
||||
with wav_file:
|
||||
wav_file.setframerate(audio_start.rate)
|
||||
wav_file.setsampwidth(audio_start.width)
|
||||
wav_file.setnchannels(audio_start.channels)
|
||||
|
||||
wav_io.seek(0)
|
||||
yield wav_io.getvalue()
|
||||
|
||||
wav_header_sent = True
|
||||
elif SynthesizeStopped.is_type(event.type):
|
||||
# All TTS audio has been received
|
||||
break
|
||||
except (OSError, WyomingError):
|
||||
# Disconnected
|
||||
_LOGGER.warning("Unexpected disconnection from TTS client")
|
||||
|
Reference in New Issue
Block a user