Enable strict typing of assist_pipeline (#91529)

This commit is contained in:
Erik Montnemery
2023-04-17 10:32:14 +02:00
committed by GitHub
parent 9985516f80
commit 3367e86686
5 changed files with 31 additions and 13 deletions

View File

@ -61,6 +61,7 @@ homeassistant.components.anthemav.*
homeassistant.components.apcupsd.*
homeassistant.components.aqualogic.*
homeassistant.components.aseko_pool_live.*
homeassistant.components.assist_pipeline.*
homeassistant.components.asuswrt.*
homeassistant.components.auth.*
homeassistant.components.automation.*

View File

@ -179,7 +179,7 @@ class PipelineRun:
tts_engine: str | None = None
tts_options: dict | None = None
def __post_init__(self):
def __post_init__(self) -> None:
"""Set language for pipeline."""
self.language = self.pipeline.language or self.hass.config.language
@ -189,7 +189,7 @@ class PipelineRun:
):
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
def start(self):
def start(self) -> None:
"""Emit run start event."""
data = {
"pipeline": self.pipeline.name,
@ -200,7 +200,7 @@ class PipelineRun:
self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
def end(self):
def end(self) -> None:
"""Emit run end event."""
self.event_callback(
PipelineEvent(
@ -349,7 +349,9 @@ class PipelineRun:
)
)
speech = conversation_result.response.speech.get("plain", {}).get("speech", "")
speech: str = conversation_result.response.speech.get("plain", {}).get(
"speech", ""
)
return speech
@ -453,7 +455,7 @@ class PipelineInput:
conversation_id: str | None = None
async def execute(self):
async def execute(self) -> None:
"""Run pipeline."""
self.run.start()
current_stage = self.run.start_stage
@ -496,7 +498,7 @@ class PipelineInput:
self.run.end()
async def validate(self):
async def validate(self) -> None:
"""Validate pipeline input against start stage."""
if self.run.start_stage == PipelineStage.STT:
if self.stt_metadata is None:
@ -524,7 +526,8 @@ class PipelineInput:
prepare_tasks = []
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT):
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata))
# self.stt_metadata can't be None or we'd raise above
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata)) # type: ignore[arg-type]
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT):
prepare_tasks.append(self.run.prepare_recognize_intent())
@ -696,7 +699,7 @@ class PipelineStorageCollectionWebsocket(
connection.send_result(msg["id"])
async def async_setup_pipeline_store(hass):
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
"""Set up the pipeline storage collection."""
pipeline_store = PipelineStorageCollection(
Store(hass, STORAGE_VERSION, STORAGE_KEY)

View File

@ -48,14 +48,14 @@ class VoiceCommandSegmenter:
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
_seconds_per_chunk: float = 0.03 # 30 ms
def __post_init__(self):
def __post_init__(self) -> None:
"""Initialize VAD."""
self._vad = webrtcvad.Vad(self.vad_mode)
self._bytes_per_chunk = self.vad_frames * 2
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
self.reset()
def reset(self):
def reset(self) -> None:
"""Reset all counters and state."""
self._audio_buffer = b""
self._speech_seconds_left = self.speech_seconds

View File

@ -1,7 +1,7 @@
"""Assist pipeline Websocket API."""
import asyncio
import audioop # pylint: disable=deprecated-module
from collections.abc import Callable
from collections.abc import AsyncGenerator, Callable
import logging
from typing import Any
@ -114,7 +114,7 @@ async def websocket_run(
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
incoming_sample_rate = msg["input"]["sample_rate"]
async def stt_stream():
async def stt_stream() -> AsyncGenerator[bytes, None]:
state = None
segmenter = VoiceCommandSegmenter()
@ -129,7 +129,11 @@ async def websocket_run(
yield chunk
def handle_binary(_hass, _connection, data: bytes):
def handle_binary(
_hass: HomeAssistant,
_connection: websocket_api.ActiveConnection,
data: bytes,
) -> None:
# Forward to STT audio stream
audio_queue.put_nowait(data)

View File

@ -371,6 +371,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.assist_pipeline.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.asuswrt.*]
check_untyped_defs = true
disallow_incomplete_defs = true