From 1a9d1a96494d57f18717bb07e8e181353691967c Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Mon, 11 Aug 2025 11:47:29 -0500 Subject: [PATCH] Handle non-streaming TTS case correctly (#150218) --- homeassistant/components/tts/__init__.py | 7 +- homeassistant/components/tts/entity.py | 12 +++ tests/components/tts/test_entity.py | 28 +++++++ tests/components/tts/test_init.py | 31 +++++++ .../wyoming/snapshots/test_tts.ambr | 80 ------------------- 5 files changed, 77 insertions(+), 81 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index cf9099448df..629332d9d64 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -976,11 +976,15 @@ class SpeechManager: if engine_instance.name is None or engine_instance.name is UNDEFINED: raise HomeAssistantError("TTS engine name is not set.") - if isinstance(engine_instance, Provider): + if isinstance(engine_instance, Provider) or ( + not engine_instance.async_supports_streaming_input() + ): + # Non-streaming if isinstance(message_or_stream, str): message = message_or_stream else: message = "".join([chunk async for chunk in message_or_stream]) + extension, data = await engine_instance.async_internal_get_tts_audio( message, language, options ) @@ -996,6 +1000,7 @@ class SpeechManager: data_gen = make_data_generator(data) else: + # Streaming if isinstance(message_or_stream, str): async def gen_stream() -> AsyncGenerator[str]: diff --git a/homeassistant/components/tts/entity.py b/homeassistant/components/tts/entity.py index aea5be6d0da..77abaa26bab 100644 --- a/homeassistant/components/tts/entity.py +++ b/homeassistant/components/tts/entity.py @@ -191,6 +191,18 @@ class TextToSpeechEntity(RestoreEntity, cached_properties=CACHED_PROPERTIES_WITH """Load tts audio file from the engine.""" raise NotImplementedError + @final + async def async_internal_get_tts_audio( + self, message: str, language: str, options: dict[str, Any] + ) -> TtsAudioType: + """Load tts audio file from the engine and update state. + + Return a tuple of file extension and data as bytes. + """ + self.__last_tts_loaded = dt_util.utcnow().isoformat() + self.async_write_ha_state() + return await self.async_get_tts_audio(message, language, options=options) + async def async_get_tts_audio( self, message: str, language: str, options: dict[str, Any] ) -> TtsAudioType: diff --git a/tests/components/tts/test_entity.py b/tests/components/tts/test_entity.py index 8648ca95e93..308d3bb0fca 100644 --- a/tests/components/tts/test_entity.py +++ b/tests/components/tts/test_entity.py @@ -175,3 +175,31 @@ def test_streaming_supported() -> None: sync_non_streaming_entity = SyncNonStreamingEntity() assert sync_non_streaming_entity.async_supports_streaming_input() is False + + +async def test_internal_get_tts_audio_writes_state( + hass: HomeAssistant, + mock_tts_entity: MockTTSEntity, +) -> None: + """Test that only async_internal_get_tts_audio updates and writes the state.""" + + entity_id = f"{tts.DOMAIN}.{TEST_DOMAIN}" + + config_entry = await mock_config_entry_setup(hass, mock_tts_entity) + assert config_entry.state is ConfigEntryState.LOADED + state1 = hass.states.get(entity_id) + assert state1 is not None + + # State should *not* change with external method + await mock_tts_entity.async_get_tts_audio("test message", hass.config.language, {}) + state2 = hass.states.get(entity_id) + assert state2 is not None + assert state1.state == state2.state + + # State *should* change with internal method + await mock_tts_entity.async_internal_get_tts_audio( + "test message", hass.config.language, {} + ) + state3 = hass.states.get(entity_id) + assert state3 is not None + assert state1.state != state3.state diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index db42da5de0e..be155aae182 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -2032,3 +2032,34 @@ async def test_tts_cache() -> None: assert await consume_mid_data_task == b"012" with pytest.raises(ValueError): assert await consume_pre_data_loaded_task == b"012" + + +async def test_async_internal_get_tts_audio_called( + hass: HomeAssistant, + mock_tts_entity: MockTTSEntity, + hass_client: ClientSessionGenerator, +) -> None: + """Test that non-streaming entity has its async_internal_get_tts_audio method called.""" + + await mock_config_entry_setup(hass, mock_tts_entity) + + # Non-streaming + assert mock_tts_entity.async_supports_streaming_input() is False + + with patch( + "homeassistant.components.tts.entity.TextToSpeechEntity.async_internal_get_tts_audio" + ) as internal_get_tts_audio: + media_source_id = tts.generate_media_source_id( + hass, + "test message", + "tts.test", + "en_US", + cache=None, + ) + + url = await get_media_source_url(hass, media_source_id) + client = await hass_client() + await client.get(url) + + # async_internal_get_tts_audio is called + internal_get_tts_audio.assert_called_once_with("test message", "en_US", {}) diff --git a/tests/components/wyoming/snapshots/test_tts.ambr b/tests/components/wyoming/snapshots/test_tts.ambr index 67c9b24160c..53cc02eaacf 100644 --- a/tests/components/wyoming/snapshots/test_tts.ambr +++ b/tests/components/wyoming/snapshots/test_tts.ambr @@ -1,19 +1,6 @@ # serializer version: 1 # name: test_get_tts_audio list([ - dict({ - 'data': dict({ - }), - 'payload': None, - 'type': 'synthesize-start', - }), - dict({ - 'data': dict({ - 'text': 'Hello world', - }), - 'payload': None, - 'type': 'synthesize-chunk', - }), dict({ 'data': dict({ 'text': 'Hello world', @@ -21,29 +8,10 @@ 'payload': None, 'type': 'synthesize', }), - dict({ - 'data': dict({ - }), - 'payload': None, - 'type': 'synthesize-stop', - }), ]) # --- # name: test_get_tts_audio_different_formats list([ - dict({ - 'data': dict({ - }), - 'payload': None, - 'type': 'synthesize-start', - }), - dict({ - 'data': dict({ - 'text': 'Hello world', - }), - 'payload': None, - 'type': 'synthesize-chunk', - }), dict({ 'data': dict({ 'text': 'Hello world', @@ -51,29 +19,10 @@ 'payload': None, 'type': 'synthesize', }), - dict({ - 'data': dict({ - }), - 'payload': None, - 'type': 'synthesize-stop', - }), ]) # --- # name: test_get_tts_audio_different_formats.1 list([ - dict({ - 'data': dict({ - }), - 'payload': None, - 'type': 'synthesize-start', - }), - dict({ - 'data': dict({ - 'text': 'Hello world', - }), - 'payload': None, - 'type': 'synthesize-chunk', - }), dict({ 'data': dict({ 'text': 'Hello world', @@ -81,12 +30,6 @@ 'payload': None, 'type': 'synthesize', }), - dict({ - 'data': dict({ - }), - 'payload': None, - 'type': 'synthesize-stop', - }), ]) # --- # name: test_get_tts_audio_streaming @@ -128,23 +71,6 @@ # --- # name: test_voice_speaker list([ - dict({ - 'data': dict({ - 'voice': dict({ - 'name': 'voice1', - 'speaker': 'speaker1', - }), - }), - 'payload': None, - 'type': 'synthesize-start', - }), - dict({ - 'data': dict({ - 'text': 'Hello world', - }), - 'payload': None, - 'type': 'synthesize-chunk', - }), dict({ 'data': dict({ 'text': 'Hello world', @@ -156,11 +82,5 @@ 'payload': None, 'type': 'synthesize', }), - dict({ - 'data': dict({ - }), - 'payload': None, - 'type': 'synthesize-stop', - }), ]) # ---