Handle non-streaming TTS case correctly (#150218)

This commit is contained in:
Michael Hansen
2025-08-11 11:47:29 -05:00
committed by GitHub
parent cb7c7767b5
commit 1a9d1a9649
5 changed files with 77 additions and 81 deletions

View File

@@ -976,11 +976,15 @@ class SpeechManager:
if engine_instance.name is None or engine_instance.name is UNDEFINED: if engine_instance.name is None or engine_instance.name is UNDEFINED:
raise HomeAssistantError("TTS engine name is not set.") raise HomeAssistantError("TTS engine name is not set.")
if isinstance(engine_instance, Provider): if isinstance(engine_instance, Provider) or (
not engine_instance.async_supports_streaming_input()
):
# Non-streaming
if isinstance(message_or_stream, str): if isinstance(message_or_stream, str):
message = message_or_stream message = message_or_stream
else: else:
message = "".join([chunk async for chunk in message_or_stream]) message = "".join([chunk async for chunk in message_or_stream])
extension, data = await engine_instance.async_internal_get_tts_audio( extension, data = await engine_instance.async_internal_get_tts_audio(
message, language, options message, language, options
) )
@@ -996,6 +1000,7 @@ class SpeechManager:
data_gen = make_data_generator(data) data_gen = make_data_generator(data)
else: else:
# Streaming
if isinstance(message_or_stream, str): if isinstance(message_or_stream, str):
async def gen_stream() -> AsyncGenerator[str]: async def gen_stream() -> AsyncGenerator[str]:

View File

@@ -191,6 +191,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
"""Load tts audio file from the engine.""" """Load tts audio file from the engine."""
raise NotImplementedError raise NotImplementedError
@final
async def async_internal_get_tts_audio(
self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType:
"""Load tts audio file from the engine and update state.
Return a tuple of file extension and data as bytes.
"""
self.__last_tts_loaded = dt_util.utcnow().isoformat()
self.async_write_ha_state()
return await self.async_get_tts_audio(message, language, options=options)
async def async_get_tts_audio( async def async_get_tts_audio(
self, message: str, language: str, options: dict[str, Any] self, message: str, language: str, options: dict[str, Any]
) -> TtsAudioType: ) -> TtsAudioType:

View File

@@ -175,3 +175,31 @@ def test_streaming_supported() -> None:
sync_non_streaming_entity = SyncNonStreamingEntity() sync_non_streaming_entity = SyncNonStreamingEntity()
assert sync_non_streaming_entity.async_supports_streaming_input() is False assert sync_non_streaming_entity.async_supports_streaming_input() is False
async def test_internal_get_tts_audio_writes_state(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
) -> None:
"""Test that only async_internal_get_tts_audio updates and writes the state."""
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
config_entry = await mock_config_entry_setup(hass, mock_tts_entity)
assert config_entry.state is ConfigEntryState.LOADED
state1 = hass.states.get(entity_id)
assert state1 is not None
# State should *not* change with external method
await mock_tts_entity.async_get_tts_audio("test message", hass.config.language, {})
state2 = hass.states.get(entity_id)
assert state2 is not None
assert state1.state == state2.state
# State *should* change with internal method
await mock_tts_entity.async_internal_get_tts_audio(
"test message", hass.config.language, {}
)
state3 = hass.states.get(entity_id)
assert state3 is not None
assert state1.state != state3.state

View File

@@ -2032,3 +2032,34 @@ async def test_tts_cache() -> None:
assert await consume_mid_data_task == b"012" assert await consume_mid_data_task == b"012"
with pytest.raises(ValueError): with pytest.raises(ValueError):
assert await consume_pre_data_loaded_task == b"012" assert await consume_pre_data_loaded_task == b"012"
async def test_async_internal_get_tts_audio_called(
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
hass_client: ClientSessionGenerator,
) -> None:
"""Test that non-streaming entity has its async_internal_get_tts_audio method called."""
await mock_config_entry_setup(hass, mock_tts_entity)
# Non-streaming
assert mock_tts_entity.async_supports_streaming_input() is False
with patch(
"homeassistant.components.tts.entity.TextToSpeechEntity.async_internal_get_tts_audio"
) as internal_get_tts_audio:
media_source_id = tts.generate_media_source_id(
hass,
"test message",
"tts.test",
"en_US",
cache=None,
)
url = await get_media_source_url(hass, media_source_id)
client = await hass_client()
await client.get(url)
# async_internal_get_tts_audio is called
internal_get_tts_audio.assert_called_once_with("test message", "en_US", {})

View File

@@ -1,19 +1,6 @@
# serializer version: 1 # serializer version: 1
# name: test_get_tts_audio # name: test_get_tts_audio
list([ list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({ dict({
'data': dict({ 'data': dict({
'text': 'Hello world', 'text': 'Hello world',
@@ -21,29 +8,10 @@
'payload': None, 'payload': None,
'type': 'synthesize', 'type': 'synthesize',
}), }),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
]) ])
# --- # ---
# name: test_get_tts_audio_different_formats # name: test_get_tts_audio_different_formats
list([ list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({ dict({
'data': dict({ 'data': dict({
'text': 'Hello world', 'text': 'Hello world',
@@ -51,29 +19,10 @@
'payload': None, 'payload': None,
'type': 'synthesize', 'type': 'synthesize',
}), }),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
]) ])
# --- # ---
# name: test_get_tts_audio_different_formats.1 # name: test_get_tts_audio_different_formats.1
list([ list([
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({ dict({
'data': dict({ 'data': dict({
'text': 'Hello world', 'text': 'Hello world',
@@ -81,12 +30,6 @@
'payload': None, 'payload': None,
'type': 'synthesize', 'type': 'synthesize',
}), }),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
]) ])
# --- # ---
# name: test_get_tts_audio_streaming # name: test_get_tts_audio_streaming
@@ -128,23 +71,6 @@
# --- # ---
# name: test_voice_speaker # name: test_voice_speaker
list([ list([
dict({
'data': dict({
'voice': dict({
'name': 'voice1',
'speaker': 'speaker1',
}),
}),
'payload': None,
'type': 'synthesize-start',
}),
dict({
'data': dict({
'text': 'Hello world',
}),
'payload': None,
'type': 'synthesize-chunk',
}),
dict({ dict({
'data': dict({ 'data': dict({
'text': 'Hello world', 'text': 'Hello world',
@@ -156,11 +82,5 @@
'payload': None, 'payload': None,
'type': 'synthesize', 'type': 'synthesize',
}), }),
dict({
'data': dict({
}),
'payload': None,
'type': 'synthesize-stop',
}),
]) ])
# --- # ---