mirror of
https://github.com/home-assistant/core.git
synced 2025-06-25 01:21:51 +02:00
Enable strict typing of assist_pipeline (#91529)
This commit is contained in:
@ -61,6 +61,7 @@ homeassistant.components.anthemav.*
|
||||
homeassistant.components.apcupsd.*
|
||||
homeassistant.components.aqualogic.*
|
||||
homeassistant.components.aseko_pool_live.*
|
||||
homeassistant.components.assist_pipeline.*
|
||||
homeassistant.components.asuswrt.*
|
||||
homeassistant.components.auth.*
|
||||
homeassistant.components.automation.*
|
||||
|
@ -179,7 +179,7 @@ class PipelineRun:
|
||||
tts_engine: str | None = None
|
||||
tts_options: dict | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
"""Set language for pipeline."""
|
||||
self.language = self.pipeline.language or self.hass.config.language
|
||||
|
||||
@ -189,7 +189,7 @@ class PipelineRun:
|
||||
):
|
||||
raise InvalidPipelineStagesError(self.start_stage, self.end_stage)
|
||||
|
||||
def start(self):
|
||||
def start(self) -> None:
|
||||
"""Emit run start event."""
|
||||
data = {
|
||||
"pipeline": self.pipeline.name,
|
||||
@ -200,7 +200,7 @@ class PipelineRun:
|
||||
|
||||
self.event_callback(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||
|
||||
def end(self):
|
||||
def end(self) -> None:
|
||||
"""Emit run end event."""
|
||||
self.event_callback(
|
||||
PipelineEvent(
|
||||
@ -349,7 +349,9 @@ class PipelineRun:
|
||||
)
|
||||
)
|
||||
|
||||
speech = conversation_result.response.speech.get("plain", {}).get("speech", "")
|
||||
speech: str = conversation_result.response.speech.get("plain", {}).get(
|
||||
"speech", ""
|
||||
)
|
||||
|
||||
return speech
|
||||
|
||||
@ -453,7 +455,7 @@ class PipelineInput:
|
||||
|
||||
conversation_id: str | None = None
|
||||
|
||||
async def execute(self):
|
||||
async def execute(self) -> None:
|
||||
"""Run pipeline."""
|
||||
self.run.start()
|
||||
current_stage = self.run.start_stage
|
||||
@ -496,7 +498,7 @@ class PipelineInput:
|
||||
|
||||
self.run.end()
|
||||
|
||||
async def validate(self):
|
||||
async def validate(self) -> None:
|
||||
"""Validate pipeline input against start stage."""
|
||||
if self.run.start_stage == PipelineStage.STT:
|
||||
if self.stt_metadata is None:
|
||||
@ -524,7 +526,8 @@ class PipelineInput:
|
||||
prepare_tasks = []
|
||||
|
||||
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.STT):
|
||||
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata))
|
||||
# self.stt_metadata can't be None or we'd raise above
|
||||
prepare_tasks.append(self.run.prepare_speech_to_text(self.stt_metadata)) # type: ignore[arg-type]
|
||||
|
||||
if start_stage_index <= PIPELINE_STAGE_ORDER.index(PipelineStage.INTENT):
|
||||
prepare_tasks.append(self.run.prepare_recognize_intent())
|
||||
@ -696,7 +699,7 @@ class PipelineStorageCollectionWebsocket(
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
async def async_setup_pipeline_store(hass):
|
||||
async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
|
||||
"""Set up the pipeline storage collection."""
|
||||
pipeline_store = PipelineStorageCollection(
|
||||
Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||
|
@ -48,14 +48,14 @@ class VoiceCommandSegmenter:
|
||||
_bytes_per_chunk: int = 480 * 2 # 16-bit samples
|
||||
_seconds_per_chunk: float = 0.03 # 30 ms
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
"""Initialize VAD."""
|
||||
self._vad = webrtcvad.Vad(self.vad_mode)
|
||||
self._bytes_per_chunk = self.vad_frames * 2
|
||||
self._seconds_per_chunk = self.vad_frames / _SAMPLE_RATE
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Reset all counters and state."""
|
||||
self._audio_buffer = b""
|
||||
self._speech_seconds_left = self.speech_seconds
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Assist pipeline Websocket API."""
|
||||
import asyncio
|
||||
import audioop # pylint: disable=deprecated-module
|
||||
from collections.abc import Callable
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@ -114,7 +114,7 @@ async def websocket_run(
|
||||
audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()
|
||||
incoming_sample_rate = msg["input"]["sample_rate"]
|
||||
|
||||
async def stt_stream():
|
||||
async def stt_stream() -> AsyncGenerator[bytes, None]:
|
||||
state = None
|
||||
segmenter = VoiceCommandSegmenter()
|
||||
|
||||
@ -129,7 +129,11 @@ async def websocket_run(
|
||||
|
||||
yield chunk
|
||||
|
||||
def handle_binary(_hass, _connection, data: bytes):
|
||||
def handle_binary(
|
||||
_hass: HomeAssistant,
|
||||
_connection: websocket_api.ActiveConnection,
|
||||
data: bytes,
|
||||
) -> None:
|
||||
# Forward to STT audio stream
|
||||
audio_queue.put_nowait(data)
|
||||
|
||||
|
10
mypy.ini
10
mypy.ini
@ -371,6 +371,16 @@ disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.assist_pipeline.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.asuswrt.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
Reference in New Issue
Block a user