Switch from WebRTC to microVAD (#122861)

* Switch WebRTC to microVAD

* Remove webrtc-noise-gain from licenses
Author: Michael Hansen
Date: 2024-07-31 02:42:45 -05:00 (committed by GitHub)
Parent: beb2ef121e
Commit: 7f4dabf546
15 changed files with 320 additions and 347 deletions

View File

@ -0,0 +1,82 @@
"""Audio enhancement for Assist."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
from pymicro_vad import MicroVad
_LOGGER = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class EnhancedAudioChunk:
"""Enhanced audio chunk and metadata."""
audio: bytes
"""Raw PCM audio @ 16Khz with 16-bit mono samples"""
timestamp_ms: int
"""Timestamp relative to start of audio stream (milliseconds)"""
is_speech: bool | None
"""True if audio chunk likely contains speech, False if not, None if unknown"""
class AudioEnhancer(ABC):
"""Base class for audio enhancement."""
def __init__(
self, auto_gain: int, noise_suppression: int, is_vad_enabled: bool
) -> None:
"""Initialize audio enhancer."""
self.auto_gain = auto_gain
self.noise_suppression = noise_suppression
self.is_vad_enabled = is_vad_enabled
@abstractmethod
def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk:
"""Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples."""
@property
@abstractmethod
def samples_per_chunk(self) -> int | None:
"""Return number of samples per chunk or None if chunking isn't required."""
class MicroVadEnhancer(AudioEnhancer):
"""Audio enhancer that just runs microVAD."""
def __init__(
self, auto_gain: int, noise_suppression: int, is_vad_enabled: bool
) -> None:
"""Initialize audio enhancer."""
super().__init__(auto_gain, noise_suppression, is_vad_enabled)
self.vad: MicroVad | None = None
self.threshold = 0.5
if self.is_vad_enabled:
self.vad = MicroVad()
_LOGGER.debug("Initialized microVAD with threshold=%s", self.threshold)
def enhance_chunk(self, audio: bytes, timestamp_ms: int) -> EnhancedAudioChunk:
"""Enhance chunk of PCM audio @ 16Khz with 16-bit mono samples."""
is_speech: bool | None = None
if self.vad is not None:
# Run VAD
speech_prob = self.vad.Process10ms(audio)
is_speech = speech_prob > self.threshold
return EnhancedAudioChunk(
audio=audio, timestamp_ms=timestamp_ms, is_speech=is_speech
)
@property
def samples_per_chunk(self) -> int | None:
"""Return number of samples per chunk or None if chunking isn't required."""
if self.is_vad_enabled:
return 160 # 10ms
return None
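
For orientation, the pymicro_vad API surface used here is tiny: MicroVad().Process10ms(chunk) consumes one 10 ms chunk of 16-bit mono PCM at 16 kHz (160 samples, 320 bytes) and returns a speech probability. A minimal standalone sketch of driving it directly; the raw-audio file and printing are illustrative assumptions, not part of this commit:

# Sketch: run microVAD over a raw PCM stream, 10 ms at a time.
# Assumes 16 kHz, 16-bit mono input; "audio.raw" is a hypothetical file.
from pymicro_vad import MicroVad

SAMPLES_PER_10MS = 160
BYTES_PER_10MS = SAMPLES_PER_10MS * 2  # 16-bit samples

vad = MicroVad()
threshold = 0.5

with open("audio.raw", "rb") as audio_file:
    while chunk := audio_file.read(BYTES_PER_10MS):
        if len(chunk) < BYTES_PER_10MS:
            break  # drop the trailing partial chunk
        speech_prob = vad.Process10ms(chunk)
        print("speech" if speech_prob > threshold else "silence")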

View File

@ -15,3 +15,8 @@ DATA_LAST_WAKE_UP = f"{DOMAIN}.last_wake_up"
WAKE_WORD_COOLDOWN = 2 # seconds
EVENT_RECORDING = f"{DOMAIN}_recording"
SAMPLE_RATE = 16000 # hertz
SAMPLE_WIDTH = 2 # bytes
SAMPLE_CHANNELS = 1 # mono
SAMPLES_PER_CHUNK = 240  # 15 ms @ 16Khz
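
As a quick cross-check of these constants: a chunk's byte size is samples × width × channels, and its duration is samples ÷ rate, so 240 samples is 480 bytes and 15 ms at 16 kHz (the duration in the original comment did not match the value). A tiny arithmetic sketch:

# Worked example: derive chunk size and duration from the constants above.
SAMPLE_RATE = 16000  # hertz
SAMPLE_WIDTH = 2  # bytes
SAMPLE_CHANNELS = 1  # mono
SAMPLES_PER_CHUNK = 240

bytes_per_chunk = SAMPLES_PER_CHUNK * SAMPLE_WIDTH * SAMPLE_CHANNELS  # 480 bytes
chunk_ms = SAMPLES_PER_CHUNK * 1000 // SAMPLE_RATE  # 15 ms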

View File

@ -6,5 +6,5 @@
"documentation": "https://www.home-assistant.io/integrations/assist_pipeline",
"iot_class": "local_push",
"quality_scale": "internal",
"requirements": ["webrtc-noise-gain==1.2.3"]
"requirements": ["pymicro-vad==1.0.0"]
}

View File

@ -13,14 +13,11 @@ from pathlib import Path
from queue import Empty, Queue
from threading import Thread
import time
from typing import TYPE_CHECKING, Any, Final, Literal, cast
from typing import Any, Literal, cast
import wave
import voluptuous as vol
if TYPE_CHECKING:
from webrtc_noise_gain import AudioProcessor
from homeassistant.components import (
conversation,
media_source,
@ -52,12 +49,17 @@ from homeassistant.util import (
)
from homeassistant.util.limited_size_dict import LimitedSizeDict
from .audio_enhancer import AudioEnhancer, EnhancedAudioChunk, MicroVadEnhancer
from .const import (
CONF_DEBUG_RECORDING_DIR,
DATA_CONFIG,
DATA_LAST_WAKE_UP,
DATA_MIGRATIONS,
DOMAIN,
SAMPLE_CHANNELS,
SAMPLE_RATE,
SAMPLE_WIDTH,
SAMPLES_PER_CHUNK,
WAKE_WORD_COOLDOWN,
)
from .error import (
@ -111,9 +113,6 @@ STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10
AUDIO_PROCESSOR_SAMPLES: Final = 160 # 10 ms @ 16 Khz
AUDIO_PROCESSOR_BYTES: Final = AUDIO_PROCESSOR_SAMPLES * 2 # 16-bit samples
@callback
def _async_resolve_default_pipeline_settings(
@ -503,8 +502,8 @@ class AudioSettings:
is_vad_enabled: bool = True
"""True if VAD is used to determine the end of the voice command."""
is_chunking_enabled: bool = True
"""True if audio is automatically split into 10 ms chunks (required for VAD, etc.)"""
samples_per_chunk: int | None = None
"""Number of samples that will be in each audio chunk (None for no chunking)."""
def __post_init__(self) -> None:
"""Verify settings post-initialization."""
@ -514,9 +513,6 @@ class AudioSettings:
if (self.auto_gain_dbfs < 0) or (self.auto_gain_dbfs > 31):
raise ValueError("auto_gain_dbfs must be in [0, 31]")
if self.needs_processor and (not self.is_chunking_enabled):
raise ValueError("Chunking must be enabled for audio processing")
@property
def needs_processor(self) -> bool:
"""True if an audio processor is needed."""
@ -526,19 +522,10 @@ class AudioSettings:
or (self.auto_gain_dbfs > 0)
)
@dataclass(frozen=True, slots=True)
class ProcessedAudioChunk:
"""Processed audio chunk and metadata."""
audio: bytes
"""Raw PCM audio @ 16Khz with 16-bit mono samples"""
timestamp_ms: int
"""Timestamp relative to start of audio stream (milliseconds)"""
is_speech: bool | None
"""True if audio chunk likely contains speech, False if not, None if unknown"""
@property
def is_chunking_enabled(self) -> bool:
"""True if chunk size is set."""
return self.samples_per_chunk is not None
@dataclass
@ -573,10 +560,10 @@ class PipelineRun:
debug_recording_queue: Queue[str | bytes | None] | None = None
"""Queue to communicate with debug recording thread"""
audio_processor: AudioProcessor | None = None
audio_enhancer: AudioEnhancer | None = None
"""VAD/noise suppression/auto gain"""
audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
audio_chunking_buffer: AudioBuffer | None = None
"""Buffer used when splitting audio into chunks for audio processing"""
_device_id: str | None = None
@ -601,19 +588,16 @@ class PipelineRun:
pipeline_data.pipeline_runs.add_run(self)
# Initialize with audio settings
self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
if self.audio_settings.needs_processor:
# Delay import of webrtc so HA start up is not crashing
# on older architectures (armhf).
#
# pylint: disable=import-outside-toplevel
from webrtc_noise_gain import AudioProcessor
self.audio_processor = AudioProcessor(
if self.audio_settings.needs_processor and (self.audio_enhancer is None):
# Default audio enhancer
self.audio_enhancer = MicroVadEnhancer(
self.audio_settings.auto_gain_dbfs,
self.audio_settings.noise_suppression_level,
self.audio_settings.is_vad_enabled,
)
self.audio_chunking_buffer = AudioBuffer(self.samples_per_chunk * SAMPLE_WIDTH)
def __eq__(self, other: object) -> bool:
"""Compare pipeline runs by id."""
if isinstance(other, PipelineRun):
@ -621,6 +605,14 @@ class PipelineRun:
return False
@property
def samples_per_chunk(self) -> int:
"""Return number of samples expected in each audio chunk."""
if self.audio_enhancer is not None:
return self.audio_enhancer.samples_per_chunk or SAMPLES_PER_CHUNK
return self.audio_settings.samples_per_chunk or SAMPLES_PER_CHUNK
@callback
def process_event(self, event: PipelineEvent) -> None:
"""Log an event and call listener."""
@ -688,8 +680,8 @@ class PipelineRun:
async def wake_word_detection(
self,
stream: AsyncIterable[ProcessedAudioChunk],
audio_chunks_for_stt: list[ProcessedAudioChunk],
stream: AsyncIterable[EnhancedAudioChunk],
audio_chunks_for_stt: list[EnhancedAudioChunk],
) -> wake_word.DetectionResult | None:
"""Run wake-word-detection portion of pipeline. Returns detection result."""
metadata_dict = asdict(
@ -732,10 +724,11 @@ class PipelineRun:
# Audio chunk buffer. This audio will be forwarded to speech-to-text
# after wake-word-detection.
num_audio_chunks_to_buffer = int(
(wake_word_settings.audio_seconds_to_buffer * 16000)
/ AUDIO_PROCESSOR_SAMPLES
(wake_word_settings.audio_seconds_to_buffer * SAMPLE_RATE)
/ self.samples_per_chunk
)
stt_audio_buffer: deque[ProcessedAudioChunk] | None = None
stt_audio_buffer: deque[EnhancedAudioChunk] | None = None
if num_audio_chunks_to_buffer > 0:
stt_audio_buffer = deque(maxlen=num_audio_chunks_to_buffer)
@ -797,7 +790,7 @@ class PipelineRun:
# speech-to-text so the user does not have to pause before
# speaking the voice command.
audio_chunks_for_stt.extend(
ProcessedAudioChunk(
EnhancedAudioChunk(
audio=chunk_ts[0], timestamp_ms=chunk_ts[1], is_speech=False
)
for chunk_ts in result.queued_audio
@ -819,18 +812,17 @@ class PipelineRun:
async def _wake_word_audio_stream(
self,
audio_stream: AsyncIterable[ProcessedAudioChunk],
stt_audio_buffer: deque[ProcessedAudioChunk] | None,
audio_stream: AsyncIterable[EnhancedAudioChunk],
stt_audio_buffer: deque[EnhancedAudioChunk] | None,
wake_word_vad: VoiceActivityTimeout | None,
sample_rate: int = 16000,
sample_width: int = 2,
sample_rate: int = SAMPLE_RATE,
sample_width: int = SAMPLE_WIDTH,
) -> AsyncIterable[tuple[bytes, int]]:
"""Yield audio chunks with timestamps (milliseconds since start of stream).
Adds audio to a ring buffer that will be forwarded to speech-to-text after
detection. Times out if VAD detects enough silence.
"""
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
async for chunk in audio_stream:
if self.abort_wake_word_detection:
raise WakeWordDetectionAborted
@ -845,6 +837,7 @@ class PipelineRun:
stt_audio_buffer.append(chunk)
if wake_word_vad is not None:
chunk_seconds = (len(chunk.audio) // sample_width) / sample_rate
if not wake_word_vad.process(chunk_seconds, chunk.is_speech):
raise WakeWordTimeoutError(
code="wake-word-timeout", message="Wake word was not detected"
@ -881,7 +874,7 @@ class PipelineRun:
async def speech_to_text(
self,
metadata: stt.SpeechMetadata,
stream: AsyncIterable[ProcessedAudioChunk],
stream: AsyncIterable[EnhancedAudioChunk],
) -> str:
"""Run speech-to-text portion of pipeline. Returns the spoken text."""
# Create a background task to prepare the conversation agent
@ -957,18 +950,18 @@ class PipelineRun:
async def _speech_to_text_stream(
self,
audio_stream: AsyncIterable[ProcessedAudioChunk],
audio_stream: AsyncIterable[EnhancedAudioChunk],
stt_vad: VoiceCommandSegmenter | None,
sample_rate: int = 16000,
sample_width: int = 2,
sample_rate: int = SAMPLE_RATE,
sample_width: int = SAMPLE_WIDTH,
) -> AsyncGenerator[bytes]:
"""Yield audio chunks until VAD detects silence or speech-to-text completes."""
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
sent_vad_start = False
async for chunk in audio_stream:
self._capture_chunk(chunk.audio)
if stt_vad is not None:
chunk_seconds = (len(chunk.audio) // sample_width) / sample_rate
if not stt_vad.process(chunk_seconds, chunk.is_speech):
# Silence detected at the end of voice command
self.process_event(
@ -1072,8 +1065,8 @@ class PipelineRun:
tts_options[tts.ATTR_PREFERRED_FORMAT] = self.tts_audio_output
if self.tts_audio_output == "wav":
# 16 Khz, 16-bit mono
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = 16000
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = 1
tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = SAMPLE_RATE
tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = SAMPLE_CHANNELS
try:
options_supported = await tts.async_support_options(
@ -1220,12 +1213,15 @@ class PipelineRun:
async def process_volume_only(
self,
audio_stream: AsyncIterable[bytes],
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[ProcessedAudioChunk]:
sample_rate: int = SAMPLE_RATE,
sample_width: int = SAMPLE_WIDTH,
) -> AsyncGenerator[EnhancedAudioChunk]:
"""Apply volume transformation only (no VAD/audio enhancements) with optional chunking."""
assert self.audio_chunking_buffer is not None
bytes_per_chunk = self.samples_per_chunk * sample_width
ms_per_sample = sample_rate // 1000
ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample
ms_per_chunk = self.samples_per_chunk // ms_per_sample
timestamp_ms = 0
async for chunk in audio_stream:
@ -1233,19 +1229,18 @@ class PipelineRun:
chunk = _multiply_volume(chunk, self.audio_settings.volume_multiplier)
if self.audio_settings.is_chunking_enabled:
# 10 ms chunking
for chunk_10ms in chunk_samples(
chunk, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
for sub_chunk in chunk_samples(
chunk, bytes_per_chunk, self.audio_chunking_buffer
):
yield ProcessedAudioChunk(
audio=chunk_10ms,
yield EnhancedAudioChunk(
audio=sub_chunk,
timestamp_ms=timestamp_ms,
is_speech=None, # no VAD
)
timestamp_ms += ms_per_chunk
else:
# No chunking
yield ProcessedAudioChunk(
yield EnhancedAudioChunk(
audio=chunk,
timestamp_ms=timestamp_ms,
is_speech=None, # no VAD
@ -1255,14 +1250,19 @@ class PipelineRun:
async def process_enhance_audio(
self,
audio_stream: AsyncIterable[bytes],
sample_rate: int = 16000,
sample_width: int = 2,
) -> AsyncGenerator[ProcessedAudioChunk]:
sample_rate: int = SAMPLE_RATE,
sample_width: int = SAMPLE_WIDTH,
) -> AsyncGenerator[EnhancedAudioChunk]:
"""Split audio into 10 ms chunks and apply VAD/noise suppression/auto gain/volume transformation."""
assert self.audio_processor is not None
assert self.audio_enhancer is not None
assert self.audio_enhancer.samples_per_chunk is not None
assert self.audio_chunking_buffer is not None
bytes_per_chunk = self.audio_enhancer.samples_per_chunk * sample_width
ms_per_sample = sample_rate // 1000
ms_per_chunk = (AUDIO_PROCESSOR_SAMPLES // sample_width) // ms_per_sample
ms_per_chunk = (
self.audio_enhancer.samples_per_chunk // sample_width
) // ms_per_sample
timestamp_ms = 0
async for dirty_samples in audio_stream:
@ -1272,17 +1272,11 @@ class PipelineRun:
dirty_samples, self.audio_settings.volume_multiplier
)
# Split into 10ms chunks for audio enhancements/VAD
for dirty_10ms_chunk in chunk_samples(
dirty_samples, AUDIO_PROCESSOR_BYTES, self.audio_processor_buffer
# Split into chunks for audio enhancements/VAD
for dirty_chunk in chunk_samples(
dirty_samples, bytes_per_chunk, self.audio_chunking_buffer
):
ap_result = self.audio_processor.Process10ms(dirty_10ms_chunk)
yield ProcessedAudioChunk(
audio=ap_result.audio,
timestamp_ms=timestamp_ms,
is_speech=ap_result.is_speech,
)
yield self.audio_enhancer.enhance_chunk(dirty_chunk, timestamp_ms)
timestamp_ms += ms_per_chunk
@ -1323,9 +1317,9 @@ def _pipeline_debug_recording_thread_proc(
wav_path = run_recording_dir / f"{message}.wav"
wav_writer = wave.open(str(wav_path), "wb")
wav_writer.setframerate(16000)
wav_writer.setsampwidth(2)
wav_writer.setnchannels(1)
wav_writer.setframerate(SAMPLE_RATE)
wav_writer.setsampwidth(SAMPLE_WIDTH)
wav_writer.setnchannels(SAMPLE_CHANNELS)
elif isinstance(message, bytes):
# Chunk of 16-bit mono audio at 16Khz
if wav_writer is not None:
@ -1368,8 +1362,8 @@ class PipelineInput:
"""Run pipeline."""
self.run.start(device_id=self.device_id)
current_stage: PipelineStage | None = self.run.start_stage
stt_audio_buffer: list[ProcessedAudioChunk] = []
stt_processed_stream: AsyncIterable[ProcessedAudioChunk] | None = None
stt_audio_buffer: list[EnhancedAudioChunk] = []
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
if self.stt_stream is not None:
if self.run.audio_settings.needs_processor:
@ -1423,7 +1417,7 @@ class PipelineInput:
# Send audio in the buffer first to speech-to-text, then move on to stt_stream.
# This is basically an async itertools.chain.
async def buffer_then_audio_stream() -> (
AsyncGenerator[ProcessedAudioChunk]
AsyncGenerator[EnhancedAudioChunk]
):
# Buffered audio
for chunk in stt_audio_buffer:

View File

@ -2,12 +2,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import StrEnum
import logging
from typing import Final, cast
from typing import Final
_LOGGER = logging.getLogger(__name__)
@ -35,44 +34,6 @@ class VadSensitivity(StrEnum):
return 1.0
class VoiceActivityDetector(ABC):
"""Base class for voice activity detectors (VAD)."""
@abstractmethod
def is_speech(self, chunk: bytes) -> bool:
"""Return True if audio chunk contains speech."""
@property
@abstractmethod
def samples_per_chunk(self) -> int | None:
"""Return number of samples per chunk or None if chunking is not required."""
class WebRtcVad(VoiceActivityDetector):
"""Voice activity detector based on webrtc."""
def __init__(self) -> None:
"""Initialize webrtcvad."""
# Delay import of webrtc so HA start up is not crashing
# on older architectures (armhf).
#
# pylint: disable=import-outside-toplevel
from webrtc_noise_gain import AudioProcessor
# Just VAD: no noise suppression or auto gain
self._audio_processor = AudioProcessor(0, 0)
def is_speech(self, chunk: bytes) -> bool:
"""Return True if audio chunk contains speech."""
result = self._audio_processor.Process10ms(chunk)
return cast(bool, result.is_speech)
@property
def samples_per_chunk(self) -> int | None:
"""Return 10 ms."""
return int(0.01 * _SAMPLE_RATE) # 10 ms
class AudioBuffer:
"""Fixed-sized audio buffer with variable internal length."""
@ -176,29 +137,38 @@ class VoiceCommandSegmenter:
if self._speech_seconds_left <= 0:
# Inside voice command
self.in_command = True
self._silence_seconds_left = self.silence_seconds
_LOGGER.debug("Voice command started")
else:
# Reset if enough silence
self._reset_seconds_left -= chunk_seconds
if self._reset_seconds_left <= 0:
self._speech_seconds_left = self.speech_seconds
self._reset_seconds_left = self.reset_seconds
elif not is_speech:
# Silence in command
self._reset_seconds_left = self.reset_seconds
self._silence_seconds_left -= chunk_seconds
if self._silence_seconds_left <= 0:
# Command finished successfully
self.reset()
_LOGGER.debug("Voice command finished")
return False
else:
# Reset if enough speech
# Speech in command.
# Reset silence counter if enough speech.
self._reset_seconds_left -= chunk_seconds
if self._reset_seconds_left <= 0:
self._silence_seconds_left = self.silence_seconds
self._reset_seconds_left = self.reset_seconds
return True
def process_with_vad(
self,
chunk: bytes,
vad: VoiceActivityDetector,
vad_samples_per_chunk: int | None,
vad_is_speech: Callable[[bytes], bool],
leftover_chunk_buffer: AudioBuffer | None,
) -> bool:
"""Process an audio chunk using an external VAD.
@ -207,20 +177,20 @@ class VoiceCommandSegmenter:
Returns False when voice command is finished.
"""
if vad.samples_per_chunk is None:
if vad_samples_per_chunk is None:
# No chunking
chunk_seconds = (len(chunk) // _SAMPLE_WIDTH) / _SAMPLE_RATE
is_speech = vad.is_speech(chunk)
is_speech = vad_is_speech(chunk)
return self.process(chunk_seconds, is_speech)
if leftover_chunk_buffer is None:
raise ValueError("leftover_chunk_buffer is required when vad uses chunking")
# With chunking
seconds_per_chunk = vad.samples_per_chunk / _SAMPLE_RATE
bytes_per_chunk = vad.samples_per_chunk * _SAMPLE_WIDTH
seconds_per_chunk = vad_samples_per_chunk / _SAMPLE_RATE
bytes_per_chunk = vad_samples_per_chunk * _SAMPLE_WIDTH
for vad_chunk in chunk_samples(chunk, bytes_per_chunk, leftover_chunk_buffer):
is_speech = vad.is_speech(vad_chunk)
is_speech = vad_is_speech(vad_chunk)
if not self.process(seconds_per_chunk, is_speech):
return False
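
With this change, callers pass the VAD's chunking requirement and a plain callable instead of a VoiceActivityDetector object. A hedged usage sketch of the new signature; the energy threshold and the incoming_chunks iterable are stand-ins, not part of this commit:

# Sketch: drive VoiceCommandSegmenter with a callable VAD.
import audioop  # stdlib; deprecated since Python 3.11

from homeassistant.components.assist_pipeline.vad import (
    AudioBuffer,
    VoiceCommandSegmenter,
)

SAMPLES_PER_CHUNK = 160  # 10 ms @ 16 kHz
BYTES_PER_CHUNK = SAMPLES_PER_CHUNK * 2  # 16-bit samples

def is_speech(chunk: bytes) -> bool:
    """Crude stand-in VAD: treat loud chunks as speech."""
    return audioop.rms(chunk, 2) > 1000

segmenter = VoiceCommandSegmenter()
leftover = AudioBuffer(BYTES_PER_CHUNK)

for chunk in incoming_chunks:  # hypothetical iterable of PCM bytes
    if not segmenter.process_with_vad(chunk, SAMPLES_PER_CHUNK, is_speech, leftover):
        break  # voice command finished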

View File

@ -24,6 +24,9 @@ from .const import (
DEFAULT_WAKE_WORD_TIMEOUT,
DOMAIN,
EVENT_RECORDING,
SAMPLE_CHANNELS,
SAMPLE_RATE,
SAMPLE_WIDTH,
)
from .error import PipelineNotFound
from .pipeline import (
@ -92,7 +95,6 @@ def async_register_websocket_api(hass: HomeAssistant) -> None:
vol.Optional("volume_multiplier"): float,
# Advanced use cases/testing
vol.Optional("no_vad"): bool,
vol.Optional("no_chunking"): bool,
}
},
extra=vol.ALLOW_EXTRA,
@ -170,9 +172,14 @@ async def websocket_run(
# Yield until we receive an empty chunk
while chunk := await audio_queue.get():
if incoming_sample_rate != 16000:
if incoming_sample_rate != SAMPLE_RATE:
chunk, state = audioop.ratecv(
chunk, 2, 1, incoming_sample_rate, 16000, state
chunk,
SAMPLE_WIDTH,
SAMPLE_CHANNELS,
incoming_sample_rate,
SAMPLE_RATE,
state,
)
yield chunk
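
The resampling above is stdlib audioop.ratecv, which converts sample rates while threading filter state through successive calls (note audioop is deprecated since Python 3.11 and removed in 3.13). A standalone sketch, assuming 16-bit mono input at 44.1 kHz:

# Sketch: stream-resample 16-bit mono PCM from 44100 Hz to 16000 Hz.
import audioop

state = None  # ratecv carries resampler state between chunks
for chunk in incoming_chunks:  # hypothetical iterable of PCM bytes
    converted, state = audioop.ratecv(chunk, 2, 1, 44100, 16000, state)
    handle_chunk(converted)  # hypothetical consumer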
@ -206,7 +213,6 @@ async def websocket_run(
auto_gain_dbfs=msg_input.get("auto_gain_dbfs", 0),
volume_multiplier=msg_input.get("volume_multiplier", 1.0),
is_vad_enabled=not msg_input.get("no_vad", False),
is_chunking_enabled=not msg_input.get("no_chunking", False),
)
elif start_stage == PipelineStage.INTENT:
# Input to conversation agent
@ -424,9 +430,9 @@ def websocket_list_languages(
connection.send_result(
msg["id"],
{
"languages": sorted(pipeline_languages)
if pipeline_languages
else pipeline_languages
"languages": (
sorted(pipeline_languages) if pipeline_languages else pipeline_languages
)
},
)

View File

@ -31,12 +31,14 @@ from homeassistant.components.assist_pipeline import (
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.audio_enhancer import (
AudioEnhancer,
MicroVadEnhancer,
)
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VadSensitivity,
VoiceActivityDetector,
VoiceCommandSegmenter,
WebRtcVad,
)
from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant
@ -233,13 +235,13 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
try:
# Wait for speech before starting pipeline
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
vad = WebRtcVad()
audio_enhancer = MicroVadEnhancer(0, 0, True)
chunk_buffer: deque[bytes] = deque(
maxlen=self.buffered_chunks_before_speech,
)
speech_detected = await self._wait_for_speech(
segmenter,
vad,
audio_enhancer,
chunk_buffer,
)
if not speech_detected:
@ -253,7 +255,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
try:
async for chunk in self._segment_audio(
segmenter,
vad,
audio_enhancer,
chunk_buffer,
):
yield chunk
@ -317,7 +319,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async def _wait_for_speech(
self,
segmenter: VoiceCommandSegmenter,
vad: VoiceActivityDetector,
audio_enhancer: AudioEnhancer,
chunk_buffer: MutableSequence[bytes],
):
"""Buffer audio chunks until speech is detected.
@ -329,13 +331,18 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
assert vad.samples_per_chunk is not None
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
assert audio_enhancer.samples_per_chunk is not None
vad_buffer = AudioBuffer(audio_enhancer.samples_per_chunk * WIDTH)
while chunk:
chunk_buffer.append(chunk)
segmenter.process_with_vad(chunk, vad, vad_buffer)
segmenter.process_with_vad(
chunk,
audio_enhancer.samples_per_chunk,
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
vad_buffer,
)
if segmenter.in_command:
# Buffer until command starts
if len(vad_buffer) > 0:
@ -351,7 +358,7 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async def _segment_audio(
self,
segmenter: VoiceCommandSegmenter,
vad: VoiceActivityDetector,
audio_enhancer: AudioEnhancer,
chunk_buffer: Sequence[bytes],
) -> AsyncIterable[bytes]:
"""Yield audio chunks until voice command has finished."""
@ -364,11 +371,16 @@ class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
async with asyncio.timeout(self.audio_timeout):
chunk = await self._audio_queue.get()
assert vad.samples_per_chunk is not None
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
assert audio_enhancer.samples_per_chunk is not None
vad_buffer = AudioBuffer(audio_enhancer.samples_per_chunk * WIDTH)
while chunk:
if not segmenter.process_with_vad(chunk, vad, vad_buffer):
if not segmenter.process_with_vad(
chunk,
audio_enhancer.samples_per_chunk,
lambda x: audio_enhancer.enhance_chunk(x, 0).is_speech is True,
vad_buffer,
):
# Voice command is finished
break
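
The lambda passed to process_with_vad above adapts the enhancer to the segmenter's callable interface; the `is True` comparison matters because EnhancedAudioChunk.is_speech is `bool | None`, and an unknown result must count as non-speech. The same adapter, written out as a sketch:

# Sketch: wrap an AudioEnhancer as the vad_is_speech callable expected by
# VoiceCommandSegmenter.process_with_vad.
from collections.abc import Callable

from homeassistant.components.assist_pipeline.audio_enhancer import AudioEnhancer

def enhancer_as_vad(audio_enhancer: AudioEnhancer) -> Callable[[bytes], bool]:
    def vad_is_speech(chunk: bytes) -> bool:
        # is_speech is bool | None; treat None (unknown) as not-speech
        return audio_enhancer.enhance_chunk(chunk, 0).is_speech is True

    return vad_is_speech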

View File

@ -45,6 +45,7 @@ Pillow==10.4.0
pip>=21.3.1
psutil-home-assistant==0.0.1
PyJWT==2.8.0
pymicro-vad==1.0.0
PyNaCl==1.5.0
pyOpenSSL==24.2.1
pyserial==3.5
@ -60,7 +61,6 @@ urllib3>=1.26.5,<2
voluptuous-openapi==0.0.5
voluptuous-serialize==2.6.0
voluptuous==0.15.2
webrtc-noise-gain==1.2.3
yarl==1.9.4
zeroconf==0.132.2

View File

@ -2007,6 +2007,9 @@ pymelcloud==2.5.9
# homeassistant.components.meteoclimatic
pymeteoclimatic==0.1.0
# homeassistant.components.assist_pipeline
pymicro-vad==1.0.0
# homeassistant.components.xiaomi_tv
pymitv==1.4.3
@ -2896,9 +2899,6 @@ weatherflow4py==0.2.21
# homeassistant.components.webmin
webmin-xmlrpc==0.0.2
# homeassistant.components.assist_pipeline
webrtc-noise-gain==1.2.3
# homeassistant.components.whirlpool
whirlpool-sixth-sense==0.18.8

View File

@ -1603,6 +1603,9 @@ pymelcloud==2.5.9
# homeassistant.components.meteoclimatic
pymeteoclimatic==0.1.0
# homeassistant.components.assist_pipeline
pymicro-vad==1.0.0
# homeassistant.components.mochad
pymochad==0.2.0
@ -2282,9 +2285,6 @@ weatherflow4py==0.2.21
# homeassistant.components.webmin
webmin-xmlrpc==0.0.2
# homeassistant.components.assist_pipeline
webrtc-noise-gain==1.2.3
# homeassistant.components.whirlpool
whirlpool-sixth-sense==0.18.8

View File

@ -172,7 +172,6 @@ EXCEPTIONS = {
"tapsaff", # https://github.com/bazwilliams/python-taps-aff/pull/5
"tellduslive", # https://github.com/molobrakos/tellduslive/pull/24
"tellsticknet", # https://github.com/molobrakos/tellsticknet/pull/33
"webrtc_noise_gain", # https://github.com/rhasspy/webrtc-noise-gain/pull/24
"vincenty", # Public domain
"zeversolar", # https://github.com/kvanzuijlen/zeversolar/pull/46
}

View File

@ -75,9 +75,7 @@ async def test_pipeline_from_audio_stream_auto(
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert process_events(events) == snapshot
@ -140,9 +138,7 @@ async def test_pipeline_from_audio_stream_legacy(
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert process_events(events) == snapshot
@ -205,9 +201,7 @@ async def test_pipeline_from_audio_stream_entity(
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert process_events(events) == snapshot
@ -271,9 +265,7 @@ async def test_pipeline_from_audio_stream_no_stt(
),
stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert not events
@ -335,24 +327,25 @@ async def test_pipeline_from_audio_stream_wake_word(
# [0, 2, ...]
wake_chunk_2 = bytes(it.islice(it.cycle(range(0, 256, 2)), BYTES_ONE_SECOND))
bytes_per_chunk = int(0.01 * BYTES_ONE_SECOND)
samples_per_chunk = 160
bytes_per_chunk = samples_per_chunk * 2 # 16-bit
async def audio_data():
# 1 second in 10 ms chunks
# 1 second in chunks
i = 0
while i < len(wake_chunk_1):
yield wake_chunk_1[i : i + bytes_per_chunk]
i += bytes_per_chunk
# 1 second in 30 ms chunks
# 1 second in chunks
i = 0
while i < len(wake_chunk_2):
yield wake_chunk_2[i : i + bytes_per_chunk]
i += bytes_per_chunk
yield b"wake word!"
yield b"part1"
yield b"part2"
for chunk in (b"wake word!", b"part1", b"part2"):
yield chunk + bytes(bytes_per_chunk - len(chunk))
yield b""
await assist_pipeline.async_pipeline_from_audio_stream(
@ -373,7 +366,7 @@ async def test_pipeline_from_audio_stream_wake_word(
audio_seconds_to_buffer=1.5
),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
is_vad_enabled=False, samples_per_chunk=samples_per_chunk
),
)
@ -390,7 +383,9 @@ async def test_pipeline_from_audio_stream_wake_word(
)
assert first_chunk == wake_chunk_1[len(wake_chunk_1) // 2 :] + wake_chunk_2
assert mock_stt_provider.received[-3:] == [b"queued audio", b"part1", b"part2"]
assert mock_stt_provider.received[-3] == b"queued audio"
assert mock_stt_provider.received[-2].startswith(b"part1")
assert mock_stt_provider.received[-1].startswith(b"part2")
async def test_pipeline_save_audio(
@ -438,9 +433,7 @@ async def test_pipeline_save_audio(
pipeline_id=pipeline.id,
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
end_stage=assist_pipeline.PipelineStage.STT,
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
pipeline_dirs = list(temp_dir.iterdir())
@ -685,9 +678,7 @@ async def test_wake_word_detection_aborted(
wake_word_settings=assist_pipeline.WakeWordSettings(
audio_seconds_to_buffer=1.5
),
audio_settings=assist_pipeline.AudioSettings(
is_vad_enabled=False, is_chunking_enabled=False
),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
),
)
await pipeline_input.validate()

View File

@ -1,11 +1,9 @@
"""Tests for voice command segmenter."""
import itertools as it
from unittest.mock import patch
from homeassistant.components.assist_pipeline.vad import (
AudioBuffer,
VoiceActivityDetector,
VoiceCommandSegmenter,
chunk_samples,
)
@ -44,59 +42,41 @@ def test_speech() -> None:
def test_audio_buffer() -> None:
"""Test audio buffer wrapping."""
class DisabledVad(VoiceActivityDetector):
def is_speech(self, chunk):
return False
samples_per_chunk = 160 # 10 ms
bytes_per_chunk = samples_per_chunk * 2
leftover_buffer = AudioBuffer(bytes_per_chunk)
@property
def samples_per_chunk(self):
return 160 # 10 ms
# Partially fill audio buffer
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
chunks = list(chunk_samples(half_chunk, bytes_per_chunk, leftover_buffer))
vad = DisabledVad()
bytes_per_chunk = vad.samples_per_chunk * 2
vad_buffer = AudioBuffer(bytes_per_chunk)
segmenter = VoiceCommandSegmenter()
assert not chunks
assert leftover_buffer.bytes() == half_chunk
with patch.object(vad, "is_speech", return_value=False) as mock_process:
# Partially fill audio buffer
half_chunk = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk // 2))
segmenter.process_with_vad(half_chunk, vad, vad_buffer)
# Fill and wrap with 1/4 chunk left over
three_quarters_chunk = bytes(
it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
)
chunks = list(chunk_samples(three_quarters_chunk, bytes_per_chunk, leftover_buffer))
assert not mock_process.called
assert vad_buffer is not None
assert vad_buffer.bytes() == half_chunk
assert len(chunks) == 1
assert (
leftover_buffer.bytes()
== three_quarters_chunk[len(three_quarters_chunk) - (bytes_per_chunk // 4) :]
)
assert chunks[0] == half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
# Fill and wrap with 1/4 chunk left over
three_quarters_chunk = bytes(
it.islice(it.cycle(range(256)), int(0.75 * bytes_per_chunk))
)
segmenter.process_with_vad(three_quarters_chunk, vad, vad_buffer)
# Run 2 chunks through
leftover_buffer.clear()
assert len(leftover_buffer) == 0
assert mock_process.call_count == 1
assert (
vad_buffer.bytes()
== three_quarters_chunk[
len(three_quarters_chunk) - (bytes_per_chunk // 4) :
]
)
assert (
mock_process.call_args[0][0]
== half_chunk + three_quarters_chunk[: bytes_per_chunk // 2]
)
two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
chunks = list(chunk_samples(two_chunks, bytes_per_chunk, leftover_buffer))
# Run 2 chunks through
segmenter.reset()
vad_buffer.clear()
assert len(vad_buffer) == 0
mock_process.reset_mock()
two_chunks = bytes(it.islice(it.cycle(range(256)), bytes_per_chunk * 2))
segmenter.process_with_vad(two_chunks, vad, vad_buffer)
assert mock_process.call_count == 2
assert len(vad_buffer) == 0
assert mock_process.call_args_list[0][0][0] == two_chunks[:bytes_per_chunk]
assert mock_process.call_args_list[1][0][0] == two_chunks[bytes_per_chunk:]
assert len(chunks) == 2
assert len(leftover_buffer) == 0
assert chunks[0] == two_chunks[:bytes_per_chunk]
assert chunks[1] == two_chunks[bytes_per_chunk:]
def test_partial_chunk() -> None:
@ -125,43 +105,3 @@ def test_chunk_samples_leftover() -> None:
assert len(chunks) == 1
assert leftover_chunk_buffer.bytes() == bytes([5, 6])
def test_vad_no_chunking() -> None:
"""Test VAD that doesn't require chunking."""
class VadNoChunk(VoiceActivityDetector):
def is_speech(self, chunk: bytes) -> bool:
return sum(chunk) > 0
@property
def samples_per_chunk(self) -> int | None:
return None
vad = VadNoChunk()
segmenter = VoiceCommandSegmenter(
speech_seconds=1.0, silence_seconds=1.0, reset_seconds=0.5
)
silence = bytes([0] * 16000)
speech = bytes([255] * (16000 // 2))
# Test with differently-sized chunks
assert vad.is_speech(speech)
assert not vad.is_speech(silence)
# Simulate voice command
assert segmenter.process_with_vad(silence, vad, None)
# begin
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
# reset with silence
assert segmenter.process_with_vad(silence, vad, None)
# resume
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
assert segmenter.process_with_vad(speech, vad, None)
# end
assert segmenter.process_with_vad(silence, vad, None)
assert not segmenter.process_with_vad(silence, vad, None)
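
The removed test above exercised the no-chunking path through the old VoiceActivityDetector object. Under the new callable-based signature the same flow would look roughly like this; a sketch, not part of the commit:

# Sketch: no-chunking path with the new process_with_vad signature.
# vad_samples_per_chunk=None means each chunk is measured as-is and no
# leftover buffer is needed.
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter

segmenter = VoiceCommandSegmenter(
    speech_seconds=1.0, silence_seconds=1.0, reset_seconds=0.5
)
silence = bytes(16000)  # 8000 samples = 0.5 s of 16-bit silence @ 16 kHz
speech = bytes([255] * (16000 // 2))  # 4000 samples = 0.25 s of loud audio

def is_speech(chunk: bytes) -> bool:
    return sum(chunk) > 0

assert segmenter.process_with_vad(silence, None, is_speech, None)
assert segmenter.process_with_vad(speech, None, is_speech, None)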

View File

@ -259,12 +259,7 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"timeout": 0,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "timeout": 0, "no_vad": True},
}
)
@ -1876,11 +1871,7 @@ async def test_wake_word_cooldown_same_id(
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
}
)
@ -1889,11 +1880,7 @@ async def test_wake_word_cooldown_same_id(
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
}
)
@ -1967,11 +1954,7 @@ async def test_wake_word_cooldown_different_ids(
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
}
)
@ -1980,11 +1963,7 @@ async def test_wake_word_cooldown_different_ids(
"type": "assist_pipeline/run",
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
}
)
@ -2094,11 +2073,7 @@ async def test_wake_word_cooldown_different_entities(
"pipeline": pipeline_id_1,
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
}
)
@ -2109,11 +2084,7 @@ async def test_wake_word_cooldown_different_entities(
"pipeline": pipeline_id_2,
"start_stage": "wake_word",
"end_stage": "tts",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
}
)
@ -2210,11 +2181,7 @@ async def test_device_capture(
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "stt",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
"device_id": satellite_device.id,
}
)
@ -2315,11 +2282,7 @@ async def test_device_capture_override(
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "stt",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
"device_id": satellite_device.id,
}
)
@ -2464,11 +2427,7 @@ async def test_device_capture_queue_full(
"type": "assist_pipeline/run",
"start_stage": "stt",
"end_stage": "stt",
"input": {
"sample_rate": 16000,
"no_vad": True,
"no_chunking": True,
},
"input": {"sample_rate": 16000, "no_vad": True},
"device_id": satellite_device.id,
}
)

View File

@ -43,9 +43,12 @@ async def test_pipeline(
"""Test that pipeline function is called from RTP protocol."""
assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk):
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
if sum(chunk) > 0:
return 1
return 0
done = asyncio.Event()
@ -98,8 +101,8 @@ async def test_pipeline(
with (
patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -238,9 +241,12 @@ async def test_tts_timeout(
"""Test that TTS will time out based on its length."""
assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk):
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
if sum(chunk) > 0:
return 1
return 0
done = asyncio.Event()
@ -298,8 +304,8 @@ async def test_tts_timeout(
with (
patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -361,9 +367,12 @@ async def test_tts_wrong_extension(
"""Test that TTS will only stream WAV audio."""
assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk):
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
if sum(chunk) > 0:
return 1
return 0
done = asyncio.Event()
@ -403,8 +412,8 @@ async def test_tts_wrong_extension(
with (
patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -456,9 +465,12 @@ async def test_tts_wrong_wav_format(
"""Test that TTS will only stream WAV audio with a specific format."""
assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk):
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
if sum(chunk) > 0:
return 1
return 0
done = asyncio.Event()
@ -505,8 +517,8 @@ async def test_tts_wrong_wav_format(
with (
patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
@ -558,9 +570,12 @@ async def test_empty_tts_output(
"""Test that TTS will not stream when output is empty."""
assert await async_setup_component(hass, "voip", {})
def is_speech(self, chunk):
def process_10ms(self, chunk):
"""Anything non-zero is speech."""
return sum(chunk) > 0
if sum(chunk) > 0:
return 1
return 0
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
@ -591,8 +606,8 @@ async def test_empty_tts_output(
with (
patch(
"homeassistant.components.assist_pipeline.vad.WebRtcVad.is_speech",
new=is_speech,
"pymicro_vad.MicroVad.Process10ms",
new=process_10ms,
),
patch(
"homeassistant.components.voip.voip.async_pipeline_from_audio_stream",