mirror of
https://github.com/home-assistant/core.git
synced 2026-06-11 11:41:42 +02:00
Allow overriding TTS result stream with media id (#151718)
This commit is contained in:
@@ -12,6 +12,7 @@ import io
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import secrets
|
||||
from time import monotonic
|
||||
@@ -26,6 +27,7 @@ import voluptuous as vol
|
||||
from homeassistant.components import ffmpeg, websocket_api
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.media_source import (
|
||||
async_resolve_media,
|
||||
generate_media_source_id as ms_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@@ -41,6 +43,7 @@ from homeassistant.core import (
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.network import get_url
|
||||
@@ -125,6 +128,8 @@ KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
||||
|
||||
SCHEMA_SERVICE_CLEAR_CACHE = vol.Schema({})
|
||||
|
||||
FFMPEG_CHUNK_SIZE: Final[int] = 4096
|
||||
|
||||
|
||||
class TTSCache:
|
||||
"""Cached bytes of a TTS result."""
|
||||
@@ -310,28 +315,31 @@ def async_get_text_to_speech_languages(hass: HomeAssistant) -> set[str]:
|
||||
|
||||
async def _async_convert_audio(
|
||||
hass: HomeAssistant,
|
||||
from_extension: str,
|
||||
audio_bytes_gen: AsyncGenerator[bytes],
|
||||
to_extension: str,
|
||||
from_extension: str | None,
|
||||
audio_input: AsyncGenerator[bytes] | str | Path,
|
||||
to_extension: str | None,
|
||||
to_sample_rate: int | None = None,
|
||||
to_sample_channels: int | None = None,
|
||||
to_sample_bytes: int | None = None,
|
||||
) -> AsyncGenerator[bytes]:
|
||||
"""Convert audio to a preferred format using ffmpeg."""
|
||||
ffmpeg_manager = ffmpeg.get_ffmpeg_manager(hass)
|
||||
is_input_gen = not isinstance(audio_input, (str, Path))
|
||||
|
||||
command = [ffmpeg_manager.binary, "-hide_banner", "-loglevel", "error"]
|
||||
if from_extension:
|
||||
command.extend(["-f", from_extension])
|
||||
|
||||
if is_input_gen:
|
||||
# Async generator
|
||||
command.extend(["-i", "pipe:0"])
|
||||
else:
|
||||
# URL or path
|
||||
command.extend(["-i", str(audio_input)])
|
||||
|
||||
if to_extension:
|
||||
command.extend(["-f", to_extension])
|
||||
|
||||
command = [
|
||||
ffmpeg_manager.binary,
|
||||
"-hide_banner",
|
||||
"-loglevel",
|
||||
"error",
|
||||
"-f",
|
||||
from_extension,
|
||||
"-i",
|
||||
"pipe:",
|
||||
"-f",
|
||||
to_extension,
|
||||
]
|
||||
if to_sample_rate is not None:
|
||||
command.extend(["-ar", str(to_sample_rate)])
|
||||
if to_sample_channels is not None:
|
||||
@@ -346,36 +354,44 @@ async def _async_convert_audio(
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*command,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdin=asyncio.subprocess.PIPE if is_input_gen else None,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
async def write_input() -> None:
|
||||
assert process.stdin
|
||||
try:
|
||||
async for chunk in audio_bytes_gen:
|
||||
process.stdin.write(chunk)
|
||||
await process.stdin.drain()
|
||||
finally:
|
||||
if process.stdin:
|
||||
process.stdin.close()
|
||||
writer_task: asyncio.Task | None = None
|
||||
|
||||
writer_task = hass.async_create_background_task(
|
||||
write_input(), "tts_ffmpeg_conversion"
|
||||
)
|
||||
if is_input_gen:
|
||||
# Input is a generator, so we must manually feed in chunks
|
||||
assert isinstance(audio_input, AsyncGenerator)
|
||||
assert process.stdin
|
||||
|
||||
async def write_input() -> None:
|
||||
assert process.stdin
|
||||
try:
|
||||
async for chunk in audio_input:
|
||||
process.stdin.write(chunk)
|
||||
await process.stdin.drain()
|
||||
finally:
|
||||
if process.stdin:
|
||||
process.stdin.close()
|
||||
|
||||
writer_task = hass.async_create_background_task(
|
||||
write_input(), "tts_ffmpeg_conversion"
|
||||
)
|
||||
|
||||
assert process.stdout
|
||||
chunk_size = 4096
|
||||
try:
|
||||
while True:
|
||||
chunk = await process.stdout.read(chunk_size)
|
||||
chunk = await process.stdout.read(FFMPEG_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
finally:
|
||||
# Ensure we wait for the input writer to complete.
|
||||
await writer_task
|
||||
if writer_task is not None:
|
||||
# Ensure we wait for the input writer to complete.
|
||||
await writer_task
|
||||
|
||||
# Wait for process termination and check for errors.
|
||||
retcode = await process.wait()
|
||||
if retcode != 0:
|
||||
@@ -470,6 +486,7 @@ class ResultStream:
|
||||
"""Class that will stream the result when available."""
|
||||
|
||||
last_used: float = field(default_factory=monotonic, init=False)
|
||||
hass: HomeAssistant
|
||||
|
||||
# Streaming/conversion properties
|
||||
token: str
|
||||
@@ -485,6 +502,9 @@ class ResultStream:
|
||||
|
||||
_manager: SpeechManager
|
||||
|
||||
# Override
|
||||
_override_media_id: str | None = None
|
||||
|
||||
@cached_property
|
||||
def url(self) -> str:
|
||||
"""Get the URL to stream the result."""
|
||||
@@ -536,12 +556,64 @@ class ResultStream:
|
||||
|
||||
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
||||
"""Get the stream of this result."""
|
||||
if self._override_media_id is not None:
|
||||
# Overridden
|
||||
async for chunk in self._async_stream_override_result():
|
||||
yield chunk
|
||||
|
||||
self.last_used = monotonic()
|
||||
return
|
||||
|
||||
cache = await self._result_cache
|
||||
async for chunk in cache.async_stream_data():
|
||||
yield chunk
|
||||
|
||||
self.last_used = monotonic()
|
||||
|
||||
def async_override_result(self, media_id: str) -> None:
|
||||
"""Override the TTS stream with a different media id."""
|
||||
self._override_media_id = media_id
|
||||
|
||||
async def _async_stream_override_result(self) -> AsyncGenerator[bytes]:
|
||||
"""Get the stream of the overridden result."""
|
||||
assert self._override_media_id is not None
|
||||
media = await async_resolve_media(self.hass, self._override_media_id)
|
||||
|
||||
# Determine if we need to do audio conversion
|
||||
preferred_extension: str | None = self.options.get(ATTR_PREFERRED_FORMAT)
|
||||
sample_rate: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_RATE)
|
||||
sample_channels: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||
sample_bytes: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_BYTES)
|
||||
|
||||
needs_conversion = (
|
||||
preferred_extension
|
||||
or (sample_rate is not None)
|
||||
or (sample_channels is not None)
|
||||
or (sample_bytes is not None)
|
||||
)
|
||||
|
||||
if not needs_conversion:
|
||||
# Stream directly from URL (no conversion)
|
||||
session = async_get_clientsession(self.hass)
|
||||
async with session.get(media.url) as response:
|
||||
async for chunk in response.content:
|
||||
yield chunk
|
||||
|
||||
return
|
||||
|
||||
# Use ffmpeg to convert audio to preferred format
|
||||
converted_audio = _async_convert_audio(
|
||||
self.hass,
|
||||
from_extension=None,
|
||||
audio_input=media.path or media.url,
|
||||
to_extension=preferred_extension,
|
||||
to_sample_rate=sample_rate,
|
||||
to_sample_channels=sample_channels,
|
||||
to_sample_bytes=sample_bytes,
|
||||
)
|
||||
async for chunk in converted_audio:
|
||||
yield chunk
|
||||
|
||||
|
||||
def _hash_options(options: dict) -> str:
|
||||
"""Hashes an options dictionary."""
|
||||
@@ -773,6 +845,7 @@ class SpeechManager:
|
||||
language=language,
|
||||
options=options,
|
||||
supports_streaming_input=supports_streaming_input,
|
||||
hass=self.hass,
|
||||
_manager=self,
|
||||
)
|
||||
self.token_to_stream[token] = result_stream
|
||||
|
||||
@@ -285,6 +285,7 @@ class MockResultStream(ResultStream):
|
||||
supports_streaming_input=True,
|
||||
language="en",
|
||||
options={},
|
||||
hass=hass,
|
||||
_manager=hass.data[DATA_TTS_MANAGER],
|
||||
)
|
||||
hass.data[DATA_TTS_MANAGER].token_to_stream[self.token] = self
|
||||
|
||||
@@ -2,14 +2,17 @@
|
||||
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
import io
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import wave
|
||||
|
||||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import ffmpeg, tts
|
||||
from homeassistant.components import ffmpeg, media_source, tts
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_ANNOUNCE,
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
@@ -40,6 +43,7 @@ from .common import (
|
||||
)
|
||||
|
||||
from tests.common import MockModule, async_mock_service, mock_integration, mock_platform
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
|
||||
@@ -2063,3 +2067,122 @@ async def test_async_internal_get_tts_audio_called(
|
||||
|
||||
# async_internal_get_tts_audio is called
|
||||
internal_get_tts_audio.assert_called_once_with("test message", "en_US", {})
|
||||
|
||||
|
||||
async def test_stream_override(
|
||||
hass: HomeAssistant,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
) -> None:
|
||||
"""Test overriding streams with a media id."""
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
|
||||
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
||||
stream.async_set_message("beer")
|
||||
stream.async_override_result("test-media-id")
|
||||
|
||||
url = "http://www.home-assistant.io/resolved.mp3"
|
||||
test_data = b"override-data"
|
||||
aioclient_mock.get(url, content=test_data)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_resolve_media",
|
||||
return_value=media_source.PlayMedia(url=url, mime_type="audio/mp3"),
|
||||
):
|
||||
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||
assert result_data == test_data
|
||||
|
||||
|
||||
async def test_stream_override_with_conversion(
|
||||
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
|
||||
) -> None:
|
||||
"""Test overriding streams with a media id that requires conversion."""
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
|
||||
stream = tts.async_create_stream(
|
||||
hass,
|
||||
mock_tts_entity.entity_id,
|
||||
options={
|
||||
tts.ATTR_PREFERRED_FORMAT: "wav",
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
|
||||
},
|
||||
)
|
||||
stream.async_set_message("beer")
|
||||
stream.async_override_result("test-media-id")
|
||||
|
||||
# Use a temp file here since ffmpeg will read it directly
|
||||
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file:
|
||||
with wave.open(wav_file, "wb") as wav_writer:
|
||||
wav_writer.setframerate(16000)
|
||||
wav_writer.setsampwidth(2)
|
||||
wav_writer.setnchannels(1)
|
||||
wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16Khz/mono
|
||||
|
||||
wav_file.seek(0)
|
||||
|
||||
url = f"file://{wav_file.name}"
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_resolve_media",
|
||||
return_value=media_source.PlayMedia(url=url, mime_type="audio/wav"),
|
||||
):
|
||||
result_data = b"".join(
|
||||
[chunk async for chunk in stream.async_stream_result()]
|
||||
)
|
||||
|
||||
# Verify the preferred format
|
||||
with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
|
||||
assert wav_reader.getframerate() == 22050
|
||||
assert wav_reader.getsampwidth() == 2
|
||||
assert wav_reader.getnchannels() == 2
|
||||
assert wav_reader.readframes(wav_reader.getnframes()) == bytes(
|
||||
22050 * 2 * 2
|
||||
) # 1 second @ 22.5Khz/stereo
|
||||
|
||||
|
||||
async def test_stream_override_with_conversion_path_preferred(
|
||||
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
|
||||
) -> None:
|
||||
"""Test overriding streams with a media id that requires conversion and has a path."""
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
|
||||
stream = tts.async_create_stream(
|
||||
hass,
|
||||
mock_tts_entity.entity_id,
|
||||
options={tts.ATTR_PREFERRED_FORMAT: "wav"},
|
||||
)
|
||||
stream.async_set_message("beer")
|
||||
stream.async_override_result("test-media-id")
|
||||
|
||||
# Use a temp file here since ffmpeg will read it directly
|
||||
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file:
|
||||
with wave.open(wav_file, "wb") as wav_writer:
|
||||
wav_writer.setframerate(16000)
|
||||
wav_writer.setsampwidth(2)
|
||||
wav_writer.setnchannels(1)
|
||||
wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16Khz/mono
|
||||
|
||||
wav_file.seek(0)
|
||||
|
||||
# Path is preferred over URL
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_resolve_media",
|
||||
return_value=media_source.PlayMedia(
|
||||
path=Path(wav_file.name),
|
||||
url="http://bad-url.com",
|
||||
mime_type="audio/wav",
|
||||
),
|
||||
):
|
||||
result_data = b"".join(
|
||||
[chunk async for chunk in stream.async_stream_result()]
|
||||
)
|
||||
|
||||
# Verify the preferred format
|
||||
with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
|
||||
assert wav_reader.getframerate() == 16000
|
||||
assert wav_reader.getsampwidth() == 2
|
||||
assert wav_reader.getnchannels() == 1
|
||||
assert wav_reader.readframes(wav_reader.getnframes()) == bytes(
|
||||
16000 * 2
|
||||
) # 1 second @ 16Khz/mono
|
||||
|
||||
Reference in New Issue
Block a user