mirror of
https://github.com/home-assistant/core.git
synced 2025-06-25 01:21:51 +02:00
Add test
This commit is contained in:
@ -69,6 +69,29 @@ TTS_INFO = Info(
|
||||
)
|
||||
]
|
||||
)
|
||||
TTS_STREAMING_INFO = Info(
|
||||
tts=[
|
||||
TtsProgram(
|
||||
name="Test Streaming TTS",
|
||||
description="Test Streaming TTS",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
voices=[
|
||||
TtsVoice(
|
||||
name="Test Voice",
|
||||
description="Test Voice",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
languages=["en-US"],
|
||||
speakers=[TtsVoiceSpeaker(name="Test Speaker")],
|
||||
version=None,
|
||||
)
|
||||
],
|
||||
version=None,
|
||||
supports_synthesize_streaming=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
WAKE_WORD_INFO = Info(
|
||||
wake=[
|
||||
WakeProgram(
|
||||
|
@ -19,6 +19,7 @@ from . import (
|
||||
SATELLITE_INFO,
|
||||
STT_INFO,
|
||||
TTS_INFO,
|
||||
TTS_STREAMING_INFO,
|
||||
WAKE_WORD_INFO,
|
||||
)
|
||||
|
||||
@ -148,6 +149,20 @@ async def init_wyoming_tts(
|
||||
return tts_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_streaming_tts(
|
||||
hass: HomeAssistant, tts_config_entry: ConfigEntry
|
||||
) -> ConfigEntry:
|
||||
"""Initialize Wyoming streaming TTS."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=TTS_STREAMING_INFO,
|
||||
):
|
||||
await hass.config_entries.async_setup(tts_config_entry.entry_id)
|
||||
|
||||
return tts_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_wyoming_wake_word(
|
||||
hass: HomeAssistant, wake_word_config_entry: ConfigEntry
|
||||
|
@ -32,6 +32,43 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_get_tts_audio_streaming
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-start',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello ',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Word.',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-chunk',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'text': 'Hello Word.',
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize',
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
}),
|
||||
'payload': None,
|
||||
'type': 'synthesize-stop',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_voice_speaker
|
||||
list([
|
||||
dict({
|
||||
|
@ -8,7 +8,8 @@ import wave
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from wyoming.audio import AudioChunk, AudioStop
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.tts import SynthesizeStopped
|
||||
|
||||
from homeassistant.components import tts, wyoming
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -43,11 +44,11 @@ async def test_get_tts_audio(
|
||||
hass: HomeAssistant, init_wyoming_tts, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test get audio."""
|
||||
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
|
||||
assert entity is not None
|
||||
assert not entity.async_supports_streaming_input()
|
||||
|
||||
audio = bytes(100)
|
||||
audio_events = [
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
]
|
||||
|
||||
# Verify audio
|
||||
audio_events = [
|
||||
@ -215,3 +216,49 @@ async def test_voice_speaker(
|
||||
),
|
||||
)
|
||||
assert mock_client.written == snapshot
|
||||
|
||||
|
||||
async def test_get_tts_audio_streaming(
|
||||
hass: HomeAssistant, init_wyoming_streaming_tts, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
"""Test get audio with streaming."""
|
||||
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_streaming_tts")
|
||||
assert entity is not None
|
||||
assert entity.async_supports_streaming_input()
|
||||
|
||||
audio = bytes(100)
|
||||
|
||||
# Verify audio
|
||||
audio_events = [
|
||||
AudioStart(rate=16000, width=2, channels=1).event(),
|
||||
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||
AudioStop().event(),
|
||||
SynthesizeStopped().event(),
|
||||
]
|
||||
|
||||
async def message_gen():
|
||||
yield "Hello "
|
||||
yield "Word."
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||
MockAsyncTcpClient(audio_events),
|
||||
) as mock_client:
|
||||
stream = tts.async_create_stream(
|
||||
hass,
|
||||
"tts.test_streaming_tts",
|
||||
"en-US",
|
||||
options={tts.ATTR_PREFERRED_FORMAT: "wav"},
|
||||
)
|
||||
stream.async_set_message_stream(message_gen())
|
||||
data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||
|
||||
assert data is not None
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
assert wav_file.getframerate() == 16000
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getnframes() == 0 # streaming
|
||||
assert data[44:] == audio # WAV header is 44 bytes
|
||||
|
||||
assert mock_client.written == snapshot
|
||||
|
Reference in New Issue
Block a user