mirror of
https://github.com/home-assistant/core.git
synced 2025-08-30 09:51:37 +02:00
Handle non-streaming TTS case correctly (#150218)
This commit is contained in:
@@ -976,11 +976,15 @@ class SpeechManager:
|
|||||||
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
if engine_instance.name is None or engine_instance.name is UNDEFINED:
|
||||||
raise HomeAssistantError("TTS engine name is not set.")
|
raise HomeAssistantError("TTS engine name is not set.")
|
||||||
|
|
||||||
if isinstance(engine_instance, Provider):
|
if isinstance(engine_instance, Provider) or (
|
||||||
|
not engine_instance.async_supports_streaming_input()
|
||||||
|
):
|
||||||
|
# Non-streaming
|
||||||
if isinstance(message_or_stream, str):
|
if isinstance(message_or_stream, str):
|
||||||
message = message_or_stream
|
message = message_or_stream
|
||||||
else:
|
else:
|
||||||
message = "".join([chunk async for chunk in message_or_stream])
|
message = "".join([chunk async for chunk in message_or_stream])
|
||||||
|
|
||||||
extension, data = await engine_instance.async_internal_get_tts_audio(
|
extension, data = await engine_instance.async_internal_get_tts_audio(
|
||||||
message, language, options
|
message, language, options
|
||||||
)
|
)
|
||||||
@@ -996,6 +1000,7 @@ class SpeechManager:
|
|||||||
data_gen = make_data_generator(data)
|
data_gen = make_data_generator(data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
# Streaming
|
||||||
if isinstance(message_or_stream, str):
|
if isinstance(message_or_stream, str):
|
||||||
|
|
||||||
async def gen_stream() -> AsyncGenerator[str]:
|
async def gen_stream() -> AsyncGenerator[str]:
|
||||||
|
@@ -191,6 +191,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH
|
|||||||
"""Load tts audio file from the engine."""
|
"""Load tts audio file from the engine."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@final
|
||||||
|
async def async_internal_get_tts_audio(
|
||||||
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
|
) -> TtsAudioType:
|
||||||
|
"""Load tts audio file from the engine and update state.
|
||||||
|
|
||||||
|
Return a tuple of file extension and data as bytes.
|
||||||
|
"""
|
||||||
|
self.__last_tts_loaded = dt_util.utcnow().isoformat()
|
||||||
|
self.async_write_ha_state()
|
||||||
|
return await self.async_get_tts_audio(message, language, options=options)
|
||||||
|
|
||||||
async def async_get_tts_audio(
|
async def async_get_tts_audio(
|
||||||
self, message: str, language: str, options: dict[str, Any]
|
self, message: str, language: str, options: dict[str, Any]
|
||||||
) -> TtsAudioType:
|
) -> TtsAudioType:
|
||||||
|
@@ -175,3 +175,31 @@ def test_streaming_supported() -> None:
|
|||||||
|
|
||||||
sync_non_streaming_entity = SyncNonStreamingEntity()
|
sync_non_streaming_entity = SyncNonStreamingEntity()
|
||||||
assert sync_non_streaming_entity.async_supports_streaming_input() is False
|
assert sync_non_streaming_entity.async_supports_streaming_input() is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_internal_get_tts_audio_writes_state(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_tts_entity: MockTTSEntity,
|
||||||
|
) -> None:
|
||||||
|
"""Test that only async_internal_get_tts_audio updates and writes the state."""
|
||||||
|
|
||||||
|
entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}"
|
||||||
|
|
||||||
|
config_entry = await mock_config_entry_setup(hass, mock_tts_entity)
|
||||||
|
assert config_entry.state is ConfigEntryState.LOADED
|
||||||
|
state1 = hass.states.get(entity_id)
|
||||||
|
assert state1 is not None
|
||||||
|
|
||||||
|
# State should *not* change with external method
|
||||||
|
await mock_tts_entity.async_get_tts_audio("test message", hass.config.language, {})
|
||||||
|
state2 = hass.states.get(entity_id)
|
||||||
|
assert state2 is not None
|
||||||
|
assert state1.state == state2.state
|
||||||
|
|
||||||
|
# State *should* change with internal method
|
||||||
|
await mock_tts_entity.async_internal_get_tts_audio(
|
||||||
|
"test message", hass.config.language, {}
|
||||||
|
)
|
||||||
|
state3 = hass.states.get(entity_id)
|
||||||
|
assert state3 is not None
|
||||||
|
assert state1.state != state3.state
|
||||||
|
@@ -2032,3 +2032,34 @@ async def test_tts_cache() -> None:
|
|||||||
assert await consume_mid_data_task == b"012"
|
assert await consume_mid_data_task == b"012"
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
assert await consume_pre_data_loaded_task == b"012"
|
assert await consume_pre_data_loaded_task == b"012"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_internal_get_tts_audio_called(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_tts_entity: MockTTSEntity,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test that non-streaming entity has its async_internal_get_tts_audio method called."""
|
||||||
|
|
||||||
|
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||||
|
|
||||||
|
# Non-streaming
|
||||||
|
assert mock_tts_entity.async_supports_streaming_input() is False
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.tts.entity.TextToSpeechEntity.async_internal_get_tts_audio"
|
||||||
|
) as internal_get_tts_audio:
|
||||||
|
media_source_id = tts.generate_media_source_id(
|
||||||
|
hass,
|
||||||
|
"test message",
|
||||||
|
"tts.test",
|
||||||
|
"en_US",
|
||||||
|
cache=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
url = await get_media_source_url(hass, media_source_id)
|
||||||
|
client = await hass_client()
|
||||||
|
await client.get(url)
|
||||||
|
|
||||||
|
# async_internal_get_tts_audio is called
|
||||||
|
internal_get_tts_audio.assert_called_once_with("test message", "en_US", {})
|
||||||
|
@@ -1,19 +1,6 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_get_tts_audio
|
# name: test_get_tts_audio
|
||||||
list([
|
list([
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-start',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'text': 'Hello world',
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-chunk',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'text': 'Hello world',
|
'text': 'Hello world',
|
||||||
@@ -21,29 +8,10 @@
|
|||||||
'payload': None,
|
'payload': None,
|
||||||
'type': 'synthesize',
|
'type': 'synthesize',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-stop',
|
|
||||||
}),
|
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_get_tts_audio_different_formats
|
# name: test_get_tts_audio_different_formats
|
||||||
list([
|
list([
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-start',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'text': 'Hello world',
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-chunk',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'text': 'Hello world',
|
'text': 'Hello world',
|
||||||
@@ -51,29 +19,10 @@
|
|||||||
'payload': None,
|
'payload': None,
|
||||||
'type': 'synthesize',
|
'type': 'synthesize',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-stop',
|
|
||||||
}),
|
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_get_tts_audio_different_formats.1
|
# name: test_get_tts_audio_different_formats.1
|
||||||
list([
|
list([
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-start',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'text': 'Hello world',
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-chunk',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'text': 'Hello world',
|
'text': 'Hello world',
|
||||||
@@ -81,12 +30,6 @@
|
|||||||
'payload': None,
|
'payload': None,
|
||||||
'type': 'synthesize',
|
'type': 'synthesize',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-stop',
|
|
||||||
}),
|
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_get_tts_audio_streaming
|
# name: test_get_tts_audio_streaming
|
||||||
@@ -128,23 +71,6 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_voice_speaker
|
# name: test_voice_speaker
|
||||||
list([
|
list([
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'voice': dict({
|
|
||||||
'name': 'voice1',
|
|
||||||
'speaker': 'speaker1',
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-start',
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'text': 'Hello world',
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-chunk',
|
|
||||||
}),
|
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'text': 'Hello world',
|
'text': 'Hello world',
|
||||||
@@ -156,11 +82,5 @@
|
|||||||
'payload': None,
|
'payload': None,
|
||||||
'type': 'synthesize',
|
'type': 'synthesize',
|
||||||
}),
|
}),
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
}),
|
|
||||||
'payload': None,
|
|
||||||
'type': 'synthesize-stop',
|
|
||||||
}),
|
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
Reference in New Issue
Block a user