mirror of
https://github.com/home-assistant/core.git
synced 2026-04-18 15:39:12 +02:00
Compare commits
2 Commits
flic2
...
synesthesi
| Author | SHA1 | Date | |
|---|---|---|---|
| | 6cf00e1f34 | | |
| | d6598366e6 | | |
@@ -945,7 +945,10 @@ class PipelineRun:
|
||||
try:
|
||||
# Transcribe audio stream
|
||||
stt_vad: VoiceCommandSegmenter | None = None
|
||||
if self.audio_settings.is_vad_enabled:
|
||||
if (
|
||||
self.audio_settings.is_vad_enabled
|
||||
and self.stt_provider.audio_processing.requires_external_vad
|
||||
):
|
||||
stt_vad = VoiceCommandSegmenter(
|
||||
silence_seconds=self.audio_settings.silence_seconds
|
||||
)
|
||||
|
||||
@@ -46,7 +46,7 @@ from .legacy import (
|
||||
async_get_provider,
|
||||
async_setup_legacy,
|
||||
)
|
||||
from .models import SpeechMetadata, SpeechResult
|
||||
from .models import SpeechAudioProcessing, SpeechMetadata, SpeechResult
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
@@ -69,6 +69,12 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
|
||||
# Audio processing settings reported by speech-to-text entities that do not
# override ``audio_processing``: an external VAD is required, and automatic
# gain and noise reduction are both preferred.
DEFAULT_AUDIO_PROCESSING = SpeechAudioProcessing(
    requires_external_vad=True,
    prefers_auto_gain_enabled=True,
    prefers_noise_reduction_enabled=True,
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_default_engine(hass: HomeAssistant) -> str | None:
|
||||
@@ -197,6 +203,11 @@ class SpeechToTextEntity(RestoreEntity):
|
||||
def supported_channels(self) -> list[AudioChannels]:
|
||||
"""Return a list of supported channels."""
|
||||
|
||||
@property
def audio_processing(self) -> SpeechAudioProcessing:
    """Return the input audio processing this entity requires or prefers.

    This base implementation returns the module-level
    ``DEFAULT_AUDIO_PROCESSING`` object.
    """
    return DEFAULT_AUDIO_PROCESSING
|
||||
|
||||
async def async_internal_added_to_hass(self) -> None:
|
||||
"""Call when the provider entity is added to hass."""
|
||||
await super().async_internal_added_to_hass()
|
||||
|
||||
@@ -26,10 +26,16 @@ from .const import (
|
||||
AudioFormats,
|
||||
AudioSampleRates,
|
||||
)
|
||||
from .models import SpeechMetadata, SpeechResult
|
||||
from .models import SpeechAudioProcessing, SpeechMetadata, SpeechResult
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Audio processing settings reported by providers that do not override
# ``audio_processing``: an external VAD is required, and automatic gain and
# noise reduction are both preferred.
DEFAULT_AUDIO_PROCESSING = SpeechAudioProcessing(
    requires_external_vad=True,
    prefers_auto_gain_enabled=True,
    prefers_noise_reduction_enabled=True,
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_default_provider(hass: HomeAssistant) -> str | None:
|
||||
@@ -143,6 +149,11 @@ class Provider(ABC):
|
||||
def supported_channels(self) -> list[AudioChannels]:
|
||||
"""Return a list of supported channels."""
|
||||
|
||||
@property
def audio_processing(self) -> SpeechAudioProcessing:
    """Return the input audio processing this provider requires or prefers.

    This base implementation returns the module-level
    ``DEFAULT_AUDIO_PROCESSING`` object.
    """
    return DEFAULT_AUDIO_PROCESSING
|
||||
|
||||
@abstractmethod
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
|
||||
@@ -30,3 +30,20 @@ class SpeechResult:
|
||||
|
||||
text: str | None
|
||||
result: SpeechResultState
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SpeechAudioProcessing:
    """Required and preferred input audio processing settings.

    Instances are immutable: a settings object may be shared between several
    speech-to-text engines (for example as a module-level default), so
    assigning to a field raises instead of silently changing the settings
    seen by every holder of the same object.
    """

    requires_external_vad: bool
    """True if an external voice activity detector (VAD) is required.

    If False, the speech-to-text entity must detect the end of speech itself.
    """

    prefers_auto_gain_enabled: bool
    """True if input audio should adjust gain automatically for best results."""

    prefers_noise_reduction_enabled: bool
    """True if input audio should apply noise reduction for best results."""
|
||||
|
||||
@@ -2153,3 +2153,123 @@ async def test_acknowledge_other_agents(
|
||||
text_to_speech.assert_not_called()
|
||||
async_converse.assert_called_once()
|
||||
get_all_targets_in_satellite_area.assert_not_called()
|
||||
|
||||
|
||||
async def test_stt_vad_enabled_based_on_audio_processing(
    hass: HomeAssistant,
    mock_stt_provider: MockSTTProvider,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
) -> None:
    """Test that VAD is enabled only when audio_processing.requires_external_vad is True."""

    async def audio_data():
        yield make_10ms_chunk(b"silence!")
        yield make_10ms_chunk(b"speech!")
        yield b""

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    async def run_stt_stage(requires_external_vad: bool) -> tuple[Mock, Mock]:
        """Run an STT-only pipeline with the segmenter class patched out.

        The mock STT provider is configured with the given
        ``requires_external_vad`` value while pipeline-level VAD stays enabled.
        Returns the patched ``VoiceCommandSegmenter`` class mock and the
        instance it is set up to hand back.
        """
        with (
            patch(
                "homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter"
            ) as mock_vad,
            patch(
                "homeassistant.components.stt.async_get_speech_to_text_engine",
                return_value=mock_stt_provider,
            ),
        ):
            # Set the audio_processing on the mock provider
            mock_stt_provider._audio_processing = stt.SpeechAudioProcessing(
                requires_external_vad=requires_external_vad,
                prefers_auto_gain_enabled=True,
                prefers_noise_reduction_enabled=True,
            )

            mock_vad_instance = Mock()
            mock_vad.return_value = mock_vad_instance
            mock_vad_instance.process.return_value = False  # No voice command

            pipeline_input = assist_pipeline.pipeline.PipelineInput(
                session=mock_chat_session,
                device_id=None,
                stt_metadata=stt.SpeechMetadata(
                    language="en-US",
                    format=stt.AudioFormats.WAV,
                    codec=stt.AudioCodecs.PCM,
                    bit_rate=stt.AudioBitRates.BITRATE_16,
                    sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                    channel=stt.AudioChannels.CHANNEL_MONO,
                ),
                # A fresh generator per run; the previous one is exhausted.
                stt_stream=audio_data(),
                run=assist_pipeline.pipeline.PipelineRun(
                    hass,
                    context=Context(),
                    pipeline=pipeline,
                    start_stage=assist_pipeline.PipelineStage.STT,
                    end_stage=assist_pipeline.PipelineStage.STT,
                    event_callback=lambda _: None,
                    audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=True),
                ),
            )
            await pipeline_input.validate()
            await pipeline_input.execute()

        return mock_vad, mock_vad_instance

    # Test with requires_external_vad=True (default)
    # VAD should be created and used
    mock_vad, mock_vad_instance = await run_stt_stage(requires_external_vad=True)
    mock_vad.assert_called_once()
    assert mock_vad_instance.process.called

    # Test with requires_external_vad=False
    # VAD should NOT be created
    mock_vad, _ = await run_stt_stage(requires_external_vad=False)
    mock_vad.assert_not_called()
|
||||
|
||||
@@ -7,12 +7,14 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.stt import (
|
||||
DEFAULT_AUDIO_PROCESSING,
|
||||
AudioBitRates,
|
||||
AudioChannels,
|
||||
AudioCodecs,
|
||||
AudioFormats,
|
||||
AudioSampleRates,
|
||||
Provider,
|
||||
SpeechAudioProcessing,
|
||||
SpeechMetadata,
|
||||
SpeechResult,
|
||||
SpeechResultState,
|
||||
@@ -34,13 +36,18 @@ class BaseProvider:
|
||||
fail_process_audio = False
|
||||
|
||||
def __init__(
    self,
    *,
    supported_languages: list[str] | None = None,
    text: str = "test_result",
    audio_processing: SpeechAudioProcessing | None = None,
) -> None:
    """Init test provider.

    All arguments are keyword-only.

    supported_languages: languages to report; a missing or empty value
        falls back to ``["de", "de-CH", "en"]``.
    text: stored as ``self.text``.
    audio_processing: settings returned by the ``audio_processing``
        property; ``None`` selects ``DEFAULT_AUDIO_PROCESSING``.
    """
    self._supported_languages = supported_languages or ["de", "de-CH", "en"]
    self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
    self.received: list[bytes] = []
    self.text = text
    # Explicit None check: only "not provided" selects the default.
    self._audio_processing = (
        DEFAULT_AUDIO_PROCESSING if audio_processing is None else audio_processing
    )
|
||||
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
@@ -72,6 +79,11 @@ class BaseProvider:
|
||||
"""Return a list of supported channels."""
|
||||
return [AudioChannels.CHANNEL_MONO]
|
||||
|
||||
@property
def audio_processing(self) -> SpeechAudioProcessing:
    """Return the audio processing settings stored on this test provider."""
    return self._audio_processing
|
||||
|
||||
async def async_process_audio_stream(
|
||||
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
|
||||
) -> SpeechResult:
|
||||
|
||||
@@ -15,6 +15,7 @@ from homeassistant.components.stt import (
|
||||
AudioCodecs,
|
||||
AudioFormats,
|
||||
AudioSampleRates,
|
||||
SpeechAudioProcessing,
|
||||
async_default_engine,
|
||||
async_get_provider,
|
||||
async_get_speech_to_text_engine,
|
||||
@@ -595,3 +596,49 @@ async def test_get_engine_entity(
|
||||
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
|
||||
|
||||
assert async_get_speech_to_text_engine(hass, "stt.test") is mock_provider_entity
|
||||
|
||||
|
||||
async def test_audio_processing_default(
    hass: HomeAssistant, tmp_path: Path, mock_provider: MockSTTProvider
) -> None:
    """Test that a provider left on defaults reports the default audio_processing."""
    await mock_setup(hass, tmp_path, mock_provider)

    engine = async_get_speech_to_text_engine(hass, TEST_DOMAIN)
    assert engine is not None

    processing = engine.audio_processing
    assert processing.requires_external_vad is True
    assert processing.prefers_auto_gain_enabled is True
    assert processing.prefers_noise_reduction_enabled is True
|
||||
|
||||
|
||||
async def test_audio_processing_entity_default(
    hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
) -> None:
    """Test that an entity left on defaults reports the default audio_processing."""
    await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)

    engine = async_get_speech_to_text_engine(hass, f"{DOMAIN}.{TEST_DOMAIN}")
    assert engine is not None

    processing = engine.audio_processing
    assert processing.requires_external_vad is True
    assert processing.prefers_auto_gain_enabled is True
    assert processing.prefers_noise_reduction_enabled is True
|
||||
|
||||
|
||||
async def test_audio_processing_custom(hass: HomeAssistant, tmp_path: Path) -> None:
    """Test that audio_processing passed to the provider is what the engine reports."""
    provider = MockSTTProvider(
        audio_processing=SpeechAudioProcessing(
            requires_external_vad=False,
            prefers_auto_gain_enabled=False,
            prefers_noise_reduction_enabled=False,
        )
    )
    await mock_setup(hass, tmp_path, provider)

    engine = async_get_speech_to_text_engine(hass, TEST_DOMAIN)
    assert engine is not None

    processing = engine.audio_processing
    assert processing.requires_external_vad is False
    assert processing.prefers_auto_gain_enabled is False
    assert processing.prefers_noise_reduction_enabled is False
|
||||
|
||||
Reference in New Issue
Block a user