Compare commits

...

2 Commits

Author SHA1 Message Date
Michael Hansen
6cf00e1f34 Add tests 2026-04-02 14:52:31 -05:00
Michael Hansen
d6598366e6 Add SpeechAudioProcessing 2026-04-02 14:28:25 -05:00
7 changed files with 225 additions and 4 deletions

View File

@@ -945,7 +945,10 @@ class PipelineRun:
try:
# Transcribe audio stream
stt_vad: VoiceCommandSegmenter | None = None
if self.audio_settings.is_vad_enabled:
if (
self.audio_settings.is_vad_enabled
and self.stt_provider.audio_processing.requires_external_vad
):
stt_vad = VoiceCommandSegmenter(
silence_seconds=self.audio_settings.silence_seconds
)

View File

@@ -46,7 +46,7 @@ from .legacy import (
async_get_provider,
async_setup_legacy,
)
from .models import SpeechMetadata, SpeechResult
from .models import SpeechAudioProcessing, SpeechMetadata, SpeechResult
__all__ = [
"DOMAIN",
@@ -69,6 +69,12 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
# Settings returned by SpeechToTextEntity.audio_processing: an external VAD is
# required, and automatic gain and noise reduction are preferred on the input.
DEFAULT_AUDIO_PROCESSING = SpeechAudioProcessing(
    requires_external_vad=True,
    prefers_auto_gain_enabled=True,
    prefers_noise_reduction_enabled=True,
)
@callback
def async_default_engine(hass: HomeAssistant) -> str | None:
@@ -197,6 +203,11 @@ class SpeechToTextEntity(RestoreEntity):
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
@property
def audio_processing(self) -> SpeechAudioProcessing:
    """Return the input audio processing this engine requires or prefers.

    This implementation hands back the module-level default settings.
    """
    return DEFAULT_AUDIO_PROCESSING
async def async_internal_added_to_hass(self) -> None:
"""Call when the provider entity is added to hass."""
await super().async_internal_added_to_hass()

View File

@@ -26,10 +26,16 @@ from .const import (
AudioFormats,
AudioSampleRates,
)
from .models import SpeechMetadata, SpeechResult
from .models import SpeechAudioProcessing, SpeechMetadata, SpeechResult
_LOGGER = logging.getLogger(__name__)
# Settings returned by Provider.audio_processing: an external VAD is required,
# and automatic gain and noise reduction are preferred on the input.
DEFAULT_AUDIO_PROCESSING = SpeechAudioProcessing(
    requires_external_vad=True,
    prefers_auto_gain_enabled=True,
    prefers_noise_reduction_enabled=True,
)
@callback
def async_default_provider(hass: HomeAssistant) -> str | None:
@@ -143,6 +149,11 @@ class Provider(ABC):
def supported_channels(self) -> list[AudioChannels]:
"""Return a list of supported channels."""
@property
def audio_processing(self) -> SpeechAudioProcessing:
    """Return the input audio processing this provider requires or prefers.

    This implementation hands back the module-level default settings.
    """
    return DEFAULT_AUDIO_PROCESSING
@abstractmethod
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]

View File

@@ -30,3 +30,20 @@ class SpeechResult:
text: str | None
result: SpeechResultState
@dataclass(frozen=True)
class SpeechAudioProcessing:
    """Required and preferred input audio processing settings.

    Instances are immutable: a single default instance is shared by every
    speech-to-text engine that does not supply its own settings, so allowing
    attribute assignment would let one engine change the defaults for all.
    """

    requires_external_vad: bool
    """True if an external voice activity detector (VAD) is required.

    If False, the speech-to-text entity must detect the end of speech itself.
    """

    prefers_auto_gain_enabled: bool
    """True if input audio should adjust gain automatically for best results."""

    prefers_noise_reduction_enabled: bool
    """True if input audio should apply noise reduction for best results."""

View File

@@ -2153,3 +2153,123 @@ async def test_acknowledge_other_agents(
text_to_speech.assert_not_called()
async_converse.assert_called_once()
get_all_targets_in_satellite_area.assert_not_called()
async def test_stt_vad_enabled_based_on_audio_processing(
    hass: HomeAssistant,
    mock_stt_provider: MockSTTProvider,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
) -> None:
    """Test the segmenter is only used when requires_external_vad is True."""

    async def audio_data():
        yield make_10ms_chunk(b"silence!")
        yield make_10ms_chunk(b"speech!")
        yield b""

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    async def run_stt_stage(requires_external_vad: bool) -> tuple[Mock, Mock]:
        """Run an STT-only pipeline and return the patched VAD class and instance."""
        segmenter = Mock()
        segmenter.process.return_value = False  # No voice command

        with (
            patch(
                "homeassistant.components.assist_pipeline.pipeline.VoiceCommandSegmenter",
                return_value=segmenter,
            ) as segmenter_cls,
            patch(
                "homeassistant.components.stt.async_get_speech_to_text_engine",
                return_value=mock_stt_provider,
            ),
        ):
            # Set the audio_processing on the mock provider
            mock_stt_provider._audio_processing = stt.SpeechAudioProcessing(
                requires_external_vad=requires_external_vad,
                prefers_auto_gain_enabled=True,
                prefers_noise_reduction_enabled=True,
            )

            stt_metadata = stt.SpeechMetadata(
                language="en-US",
                format=stt.AudioFormats.WAV,
                codec=stt.AudioCodecs.PCM,
                bit_rate=stt.AudioBitRates.BITRATE_16,
                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                channel=stt.AudioChannels.CHANNEL_MONO,
            )
            run = assist_pipeline.pipeline.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.STT,
                end_stage=assist_pipeline.PipelineStage.STT,
                event_callback=lambda _: None,
                audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=True),
            )
            pipeline_input = assist_pipeline.pipeline.PipelineInput(
                session=mock_chat_session,
                device_id=None,
                stt_metadata=stt_metadata,
                stt_stream=audio_data(),
                run=run,
            )
            await pipeline_input.validate()
            await pipeline_input.execute()

        return segmenter_cls, segmenter

    # requires_external_vad=True (default): the VAD is created and fed audio
    vad_cls, vad = await run_stt_stage(requires_external_vad=True)
    vad_cls.assert_called_once()
    assert vad.process.called

    # requires_external_vad=False: the VAD must NOT be created
    vad_cls, _ = await run_stt_stage(requires_external_vad=False)
    vad_cls.assert_not_called()

View File

@@ -7,12 +7,14 @@ from pathlib import Path
from typing import Any
from homeassistant.components.stt import (
DEFAULT_AUDIO_PROCESSING,
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechAudioProcessing,
SpeechMetadata,
SpeechResult,
SpeechResultState,
@@ -34,13 +36,18 @@ class BaseProvider:
fail_process_audio = False
def __init__(
    self,
    *,
    supported_languages: list[str] | None = None,
    text: str = "test_result",
    audio_processing: SpeechAudioProcessing | None = None,
) -> None:
    """Init test provider.

    supported_languages falls back to ["de", "de-CH", "en"] when it is None
    or empty. text is stored on the instance as-is. audio_processing is the
    value exposed through the audio_processing property; when it is None the
    shared DEFAULT_AUDIO_PROCESSING is used instead.
    """
    self._supported_languages = supported_languages or ["de", "de-CH", "en"]
    self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
    self.received: list[bytes] = []
    self.text = text
    # Explicit None check: only a missing argument selects the default.
    self._audio_processing = (
        DEFAULT_AUDIO_PROCESSING if audio_processing is None else audio_processing
    )
@property
def supported_languages(self) -> list[str]:
@@ -72,6 +79,11 @@ class BaseProvider:
"""Return a list of supported channels."""
return [AudioChannels.CHANNEL_MONO]
@property
def audio_processing(self) -> SpeechAudioProcessing:
    """Return the audio processing settings stored on this test provider."""
    return self._audio_processing
async def async_process_audio_stream(
self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
) -> SpeechResult:

View File

@@ -15,6 +15,7 @@ from homeassistant.components.stt import (
AudioCodecs,
AudioFormats,
AudioSampleRates,
SpeechAudioProcessing,
async_default_engine,
async_get_provider,
async_get_speech_to_text_engine,
@@ -595,3 +596,49 @@ async def test_get_engine_entity(
await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)
assert async_get_speech_to_text_engine(hass, "stt.test") is mock_provider_entity
async def test_audio_processing_default(
    hass: HomeAssistant, tmp_path: Path, mock_provider: MockSTTProvider
) -> None:
    """Test the default values exposed by a provider's audio_processing."""
    await mock_setup(hass, tmp_path, mock_provider)

    engine = async_get_speech_to_text_engine(hass, TEST_DOMAIN)
    assert engine is not None

    processing = engine.audio_processing
    assert processing.requires_external_vad is True
    assert processing.prefers_auto_gain_enabled is True
    assert processing.prefers_noise_reduction_enabled is True
async def test_audio_processing_entity_default(
    hass: HomeAssistant, tmp_path: Path, mock_provider_entity: MockSTTProviderEntity
) -> None:
    """Test the default values exposed by an entity's audio_processing."""
    await mock_config_entry_setup(hass, tmp_path, mock_provider_entity)

    entity_id = f"{DOMAIN}.{TEST_DOMAIN}"
    engine = async_get_speech_to_text_engine(hass, entity_id)
    assert engine is not None

    processing = engine.audio_processing
    assert processing.requires_external_vad is True
    assert processing.prefers_auto_gain_enabled is True
    assert processing.prefers_noise_reduction_enabled is True
async def test_audio_processing_custom(hass: HomeAssistant, tmp_path: Path) -> None:
    """Test that settings passed to the provider are exposed unchanged."""
    provider = MockSTTProvider(
        audio_processing=SpeechAudioProcessing(
            requires_external_vad=False,
            prefers_auto_gain_enabled=False,
            prefers_noise_reduction_enabled=False,
        )
    )
    await mock_setup(hass, tmp_path, provider)

    engine = async_get_speech_to_text_engine(hass, TEST_DOMAIN)
    assert engine is not None

    processing = engine.audio_processing
    assert processing.requires_external_vad is False
    assert processing.prefers_auto_gain_enabled is False
    assert processing.prefers_noise_reduction_enabled is False