mirror of
https://github.com/home-assistant/core.git
synced 2025-08-30 01:42:21 +02:00
Handle non-streaming TTS case correctly (#150218)
This commit is contained in:
@@ -976,11 +976,15 @@ class SpeechManager:
|
||||
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
||||
raise HomeAssistantError("TTS engine name is not set.")
|
||||
|
||||
if isinstance(engine_instance, Provider):
|
||||
if isinstance(engine_instance, Provider) or (
|
||||
not engine_instance.async_supports_streaming_input()
|
||||
):
|
||||
# Non-streaming
|
||||
if isinstance(message_or_stream, str):
|
||||
message = message_or_stream
|
||||
else:
|
||||
message = "".join([chunk async for chunk in message_or_stream])
|
||||
|
||||
extension, data = await engine_instance.async_internal_get_tts_audio(
|
||||
message, language, options
|
||||
)
|
||||
@@ -996,6 +1000,7 @@ class SpeechManager:
|
||||
data_gen = make_data_generator(data)
|
||||
|
||||
else:
|
||||
# Streaming
|
||||
if isinstance(message_or_stream, str):
|
||||
|
||||
async def gen_stream() -> AsyncGenerator[str]:
|
||||
|
@@ -191,6 +191,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
|
||||
"""Load tts audio file from the engine."""
|
||||
raise NotImplementedError
|
||||
|
||||
@final
|
||||
async def async_internal_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from the engine and update state.
|
||||
|
||||
Return a tuple of file extension and data as bytes.
|
||||
"""
|
||||
self.__last_tts_loaded = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
return await self.async_get_tts_audio(message, language, options=options)
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any]
|
||||
) -> TtsAudioType:
|
||||
|
@@ -175,3 +175,31 @@ def test_streaming_supported() -> None:
|
||||
|
||||
sync_non_streaming_entity = SyncNonStreamingEntity()
|
||||
assert sync_non_streaming_entity.async_supports_streaming_input() is False
|
||||
|
||||
|
||||
async def test_internal_get_tts_audio_writes_state(
|
||||
hass: HomeAssistant,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
) -> None:
|
||||
"""Test that only async_internal_get_tts_audio updates and writes the state."""
|
||||
|
||||
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
|
||||
|
||||
config_entry = await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
state1 = hass.states.get(entity_id)
|
||||
assert state1 is not None
|
||||
|
||||
# State should *not* change with external method
|
||||
await mock_tts_entity.async_get_tts_audio("test message", hass.config.language, {})
|
||||
state2 = hass.states.get(entity_id)
|
||||
assert state2 is not None
|
||||
assert state1.state == state2.state
|
||||
|
||||
# State *should* change with internal method
|
||||
await mock_tts_entity.async_internal_get_tts_audio(
|
||||
"test message", hass.config.language, {}
|
||||
)
|
||||
state3 = hass.states.get(entity_id)
|
||||
assert state3 is not None
|
||||
assert state1.state != state3.state
|
||||
|
@@ -2032,3 +2032,34 @@ async def test_tts_cache() -> None:
|
||||
assert await consume_mid_data_task == b"012"
|
||||
with pytest.raises(ValueError):
|
||||
assert await consume_pre_data_loaded_task == b"012"
|
||||
|
||||
|
||||
async def test_async_internal_get_tts_audio_called(
|
||||
hass: HomeAssistant,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test that non-streaming entity has its async_internal_get_tts_audio method called."""
|
||||
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
|
||||
# Non-streaming
|
||||
assert mock_tts_entity.async_supports_streaming_input() is False
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.entity.TextToSpeechEntity.async_internal_get_tts_audio"
|
||||
) as internal_get_tts_audio:
|
||||
media_source_id = tts.generate_media_source_id(
|
||||
hass,
|
||||
"test message",
|
||||
"tts.test",
|
||||
"en_US",
|
||||
cache=None,
|
||||
)
|
||||
|
||||
url = await get_media_source_url(hass, media_source_id)
|
||||
client = await hass_client()
|
||||
await client.get(url)
|
||||
|
||||
# async_internal_get_tts_audio is called
|
||||
internal_get_tts_audio.assert_called_once_with("test message", "en_US", {})
|
||||
|
@@ -1,19 +1,6 @@
|
||||
# serializer version: 1
|
||||
# name: test_get_tts_audio
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-start',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
@@ -21,29 +8,10 @@
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-stop',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_different_formats
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-start',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
@@ -51,29 +19,10 @@
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-stop',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_different_formats.1
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-start',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
@@ -81,12 +30,6 @@
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-stop',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_streaming
|
||||
@@ -128,23 +71,6 @@
|
||||
# ---
|
||||
# name: test_voice_speaker
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'voice': dict({
|
||||
'name': 'voice1',
|
||||
'speaker': 'speaker1',
|
||||
}),
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-start',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello world',
|
||||
@@ -156,11 +82,5 @@
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-stop',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
|
Reference in New Issue
Block a user