Allow overriding TTS result stream with media id (#151718)

This commit is contained in:
Michael Hansen
2025-09-09 22:30:00 -05:00
committed by GitHub
parent 7a332d489d
commit e8d5615e54
3 changed files with 230 additions and 33 deletions
+105 -32
View File
@@ -12,6 +12,7 @@ import io
import logging
import mimetypes
import os
from pathlib import Path
import re
import secrets
from time import monotonic
@@ -26,6 +27,7 @@ import voluptuous as vol
from homeassistant.components import ffmpeg, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_source import (
async_resolve_media,
generate_media_source_id as ms_generate_media_source_id,
)
from homeassistant.config_entries import ConfigEntry
@@ -41,6 +43,7 @@ from homeassistant.core import (
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url
@@ -125,6 +128,8 @@ KEY_PATTERN = "{0}_{1}_{2}_{3}"
# Voluptuous schema that declares no keys.
# NOTE(review): the name suggests it validates the clear-cache service call;
# confirm at the place where the service is registered.
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
# Number of bytes requested per read from the ffmpeg process's stdout while
# streaming converted audio.
FFMPEG_CHUNK_SIZE: Final[int] = 4096
class TTSCache:
"""Cached bytes of a TTS result."""
@@ -310,28 +315,31 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
async def _async_convert_audio(
hass: HomeAssistant,
from_extension: str,
audio_bytes_gen: AsyncGenerator[bytes],
to_extension: str,
from_extension: str | None,
audio_input: AsyncGenerator[bytes] | str | Path,
to_extension: str | None,
to_sample_rate: int | None = None,
to_sample_channels: int | None = None,
to_sample_bytes: int | None = None,
) -> AsyncGenerator[bytes]:
"""Convert audio to a preferred format using ffmpeg."""
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
is_input_gen = not isinstance(audio_input, (str, Path))
command = [ffmpeg_manager.binary, "-hide_banner", "-loglevel", "error"]
if from_extension:
command.extend(["-f", from_extension])
if is_input_gen:
# Async generator
command.extend(["-i", "pipe:0"])
else:
# URL or path
command.extend(["-i", str(audio_input)])
if to_extension:
command.extend(["-f", to_extension])
command = [
ffmpeg_manager.binary,
"-hide_banner",
"-loglevel",
"error",
"-f",
from_extension,
"-i",
"pipe:",
"-f",
to_extension,
]
if to_sample_rate is not None:
command.extend(["-ar", str(to_sample_rate)])
if to_sample_channels is not None:
@@ -346,36 +354,44 @@ async def _async_convert_audio(
process = await asyncio.create_subprocess_exec(
*command,
stdin=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE if is_input_gen else None,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
async def write_input() -> None:
assert process.stdin
try:
async for chunk in audio_bytes_gen:
process.stdin.write(chunk)
await process.stdin.drain()
finally:
if process.stdin:
process.stdin.close()
writer_task: asyncio.Task | None = None
writer_task = hass.async_create_background_task(
write_input(), "tts_ffmpeg_conversion"
)
if is_input_gen:
# Input is a generator, so we must manually feed in chunks
assert isinstance(audio_input, AsyncGenerator)
assert process.stdin
async def write_input() -> None:
assert process.stdin
try:
async for chunk in audio_input:
process.stdin.write(chunk)
await process.stdin.drain()
finally:
if process.stdin:
process.stdin.close()
writer_task = hass.async_create_background_task(
write_input(), "tts_ffmpeg_conversion"
)
assert process.stdout
chunk_size = 4096
try:
while True:
chunk = await process.stdout.read(chunk_size)
chunk = await process.stdout.read(FFMPEG_CHUNK_SIZE)
if not chunk:
break
yield chunk
finally:
# Ensure we wait for the input writer to complete.
await writer_task
if writer_task is not None:
# Ensure we wait for the input writer to complete.
await writer_task
# Wait for process termination and check for errors.
retcode = await process.wait()
if retcode != 0:
@@ -470,6 +486,7 @@ class ResultStream:
"""Class that will stream the result when available."""
last_used: float = field(default_factory=monotonic, init=False)
hass: HomeAssistant
# Streaming/conversion properties
token: str
@@ -485,6 +502,9 @@ class ResultStream:
_manager: SpeechManager
# Override
_override_media_id: str | None = None
@cached_property
def url(self) -> str:
"""Get the URL to stream the result."""
@@ -536,12 +556,64 @@ class ResultStream:
async def async_stream_result(self) -> AsyncGenerator[bytes]:
    """Yield the audio chunks of this result.

    When an override media id has been set, the chunks come from the
    overriding media; otherwise they come from the cached TTS result.
    ``last_used`` is refreshed once the stream has been fully consumed.
    """
    if self._override_media_id is None:
        # Regular path: wait for the TTS cache entry, then stream from it.
        result_cache = await self._result_cache
        chunk_source = result_cache.async_stream_data()
    else:
        # Overridden: stream the replacement media instead.
        chunk_source = self._async_stream_override_result()

    async for data in chunk_source:
        yield data

    self.last_used = monotonic()
def async_override_result(self, media_id: str) -> None:
    """Replace the audio of this stream with the media behind *media_id*.

    Only the media id is recorded here; it is resolved later, when the
    result is streamed.
    """
    self._override_media_id = media_id
async def _async_stream_override_result(self) -> AsyncGenerator[bytes]:
    """Stream the audio of the overriding media id.

    The override media id is resolved with ``async_resolve_media``. If none
    of the preferred format / sample rate / sample channels / sample bytes
    options is set, the resolved URL is downloaded and its bytes are passed
    through unchanged. Otherwise the media (its local path when one is
    available, else its URL) is handed to ffmpeg through
    ``_async_convert_audio`` and the converted audio is yielded.

    Raises:
        aiohttp.ClientResponseError: If the media URL answers with an HTTP
            error status (direct streaming path only).
    """
    assert self._override_media_id is not None
    media = await async_resolve_media(self.hass, self._override_media_id)

    # Determine if we need to do audio conversion
    preferred_extension: str | None = self.options.get(ATTR_PREFERRED_FORMAT)
    sample_rate: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_RATE)
    sample_channels: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
    sample_bytes: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_BYTES)

    needs_conversion = (
        preferred_extension
        or (sample_rate is not None)
        or (sample_channels is not None)
        or (sample_bytes is not None)
    )

    if not needs_conversion:
        # Stream directly from URL (no conversion)
        session = async_get_clientsession(self.hass)
        async with session.get(media.url) as response:
            # Fail loudly rather than streaming an HTTP error body as if it
            # were audio data.
            response.raise_for_status()

            # Iterating the StreamReader itself reads line by line: binary
            # audio would be split on b"\n" and a long run without a
            # newline can raise ValueError("Chunk too big"). iter_any()
            # yields the raw data in the chunks it was received in.
            async for chunk in response.content.iter_any():
                yield chunk

        return

    # Use ffmpeg to convert audio to preferred format.
    # The input format is left to ffmpeg's detection (from_extension=None);
    # a local path is preferred over the URL when the media provides one.
    converted_audio = _async_convert_audio(
        self.hass,
        from_extension=None,
        audio_input=media.path or media.url,
        to_extension=preferred_extension,
        to_sample_rate=sample_rate,
        to_sample_channels=sample_channels,
        to_sample_bytes=sample_bytes,
    )
    async for chunk in converted_audio:
        yield chunk
def _hash_options(options: dict) -> str:
"""Hashes an options dictionary."""
@@ -773,6 +845,7 @@ class SpeechManager:
language=language,
options=options,
supports_streaming_input=supports_streaming_input,
hass=self.hass,
_manager=self,
)
self.token_to_stream[token] = result_stream
+1
View File
@@ -285,6 +285,7 @@ class MockResultStream(ResultStream):
supports_streaming_input=True,
language="en",
options={},
hass=hass,
_manager=hass.data[DATA_TTS_MANAGER],
)
hass.data[DATA_TTS_MANAGER].token_to_stream[self.token] = self
+124 -1
View File
@@ -2,14 +2,17 @@
import asyncio
from http import HTTPStatus
import io
from pathlib import Path
import tempfile
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import wave
from freezegun.api import FrozenDateTimeFactory
import pytest
from homeassistant.components import ffmpeg, tts
from homeassistant.components import ffmpeg, media_source, tts
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
@@ -40,6 +43,7 @@ from .common import (
)
from tests.common import MockModule, async_mock_service, mock_integration, mock_platform
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator, WebSocketGenerator
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
@@ -2063,3 +2067,122 @@ async def test_async_internal_get_tts_audio_called(
# async_internal_get_tts_audio is called
internal_get_tts_audio.assert_called_once_with("test message", "en_US", {})
async def test_stream_override(
    hass: HomeAssistant,
    mock_tts_entity: MockTTSEntity,
    aioclient_mock: AiohttpClientMocker,
) -> None:
    """Check that an overridden stream serves the resolved media's bytes."""
    await mock_config_entry_setup(hass, mock_tts_entity)

    # Register the fake HTTP resource the override will resolve to.
    resolved_url = "http://www.home-assistant.io/resolved.mp3"
    expected_audio = b"override-data"
    aioclient_mock.get(resolved_url, content=expected_audio)

    stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
    stream.async_set_message("beer")
    stream.async_override_result("test-media-id")

    resolved_media = media_source.PlayMedia(url=resolved_url, mime_type="audio/mp3")
    with patch(
        "homeassistant.components.tts.async_resolve_media",
        return_value=resolved_media,
    ):
        received = bytearray()
        async for chunk in stream.async_stream_result():
            received.extend(chunk)

    # No preferred format was requested, so the bytes pass through unchanged.
    assert bytes(received) == expected_audio
async def test_stream_override_with_conversion(
    hass: HomeAssistant, mock_tts_entity: MockTTSEntity
) -> None:
    """Check that overriding media is converted to the preferred audio format."""
    await mock_config_entry_setup(hass, mock_tts_entity)

    preferred_options = {
        tts.ATTR_PREFERRED_FORMAT: "wav",
        tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
        tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
    }
    stream = tts.async_create_stream(
        hass, mock_tts_entity.entity_id, options=preferred_options
    )
    stream.async_set_message("beer")
    stream.async_override_result("test-media-id")

    # ffmpeg reads the input itself, so the source must be a real file on disk.
    with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as source_file:
        with wave.open(source_file, "wb") as source_wav:
            source_wav.setnchannels(1)
            source_wav.setsampwidth(2)
            source_wav.setframerate(16000)
            # One second of silence @ 16 kHz / mono
            source_wav.writeframes(bytes(16000 * 2))

        source_file.seek(0)
        resolved_media = media_source.PlayMedia(
            url=f"file://{source_file.name}", mime_type="audio/wav"
        )

        with patch(
            "homeassistant.components.tts.async_resolve_media",
            return_value=resolved_media,
        ):
            received = bytearray()
            async for chunk in stream.async_stream_result():
                received.extend(chunk)

    # The output must have the preferred rate, width and channel count.
    with io.BytesIO(bytes(received)) as wav_io:
        with wave.open(wav_io, "rb") as wav_reader:
            assert wav_reader.getframerate() == 22050
            assert wav_reader.getsampwidth() == 2
            assert wav_reader.getnchannels() == 2

            # One second of silence @ 22.05 kHz / stereo
            frames = wav_reader.readframes(wav_reader.getnframes())
            assert frames == bytes(22050 * 2 * 2)
async def test_stream_override_with_conversion_path_preferred(
    hass: HomeAssistant, mock_tts_entity: MockTTSEntity
) -> None:
    """Check that conversion reads the media's local path rather than its URL."""
    await mock_config_entry_setup(hass, mock_tts_entity)

    stream = tts.async_create_stream(
        hass,
        mock_tts_entity.entity_id,
        options={tts.ATTR_PREFERRED_FORMAT: "wav"},
    )
    stream.async_set_message("beer")
    stream.async_override_result("test-media-id")

    # ffmpeg reads the input itself, so the source must be a real file on disk.
    with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as source_file:
        with wave.open(source_file, "wb") as source_wav:
            source_wav.setnchannels(1)
            source_wav.setsampwidth(2)
            source_wav.setframerate(16000)
            # One second of silence @ 16 kHz / mono
            source_wav.writeframes(bytes(16000 * 2))

        source_file.seek(0)

        # Both a path and a URL are offered; the URL is unusable, so the
        # stream only succeeds if the path is the one that gets read.
        resolved_media = media_source.PlayMedia(
            path=Path(source_file.name),
            url="http://bad-url.com",
            mime_type="audio/wav",
        )

        with patch(
            "homeassistant.components.tts.async_resolve_media",
            return_value=resolved_media,
        ):
            received = bytearray()
            async for chunk in stream.async_stream_result():
                received.extend(chunk)

    # Only the container format was requested, so rate, width and channel
    # count must match the source file.
    with io.BytesIO(bytes(received)) as wav_io:
        with wave.open(wav_io, "rb") as wav_reader:
            assert wav_reader.getframerate() == 16000
            assert wav_reader.getsampwidth() == 2
            assert wav_reader.getnchannels() == 1

            # One second of silence @ 16 kHz / mono
            frames = wav_reader.readframes(wav_reader.getnframes())
            assert frames == bytes(16000 * 2)