diff --git a/homeassistant/components/websocket_api/connection.py b/homeassistant/components/websocket_api/connection.py index 63b4418a19d..14ceac60e59 100644 --- a/homeassistant/components/websocket_api/connection.py +++ b/homeassistant/components/websocket_api/connection.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Callable, Hashable from contextvars import ContextVar -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from aiohttp import web import voluptuous as vol @@ -65,9 +65,9 @@ class ActiveConnection: self.last_id = 0 self.can_coalesce = False self.supported_features: dict[str, float] = {} - self.handlers: dict[str, tuple[MessageHandler, vol.Schema]] = self.hass.data[ - const.DOMAIN - ] + self.handlers: dict[str, tuple[MessageHandler, vol.Schema | Literal[False]]] = ( + self.hass.data[const.DOMAIN] + ) self.binary_handlers: list[BinaryHandler | None] = [] current_connection.set(self) @@ -185,6 +185,7 @@ class ActiveConnection: or ( not (cur_id := msg.get("id")) or type(cur_id) is not int # noqa: E721 + or cur_id < 0 or not (type_ := msg.get("type")) or type(type_) is not str # noqa: E721 ) @@ -220,7 +221,7 @@ class ActiveConnection: handler, schema = handler_schema try: - handler(self.hass, self, schema(msg)) + handler(self.hass, self, msg if schema is False else schema(msg)) except Exception as err: # pylint: disable=broad-except self.async_handle_exception(msg, err) diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py index 51643752a0f..0ed8be30139 100644 --- a/homeassistant/components/websocket_api/decorators.py +++ b/homeassistant/components/websocket_api/decorators.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import Callable from functools import wraps -from typing import Any +from typing import TYPE_CHECKING, Any import voluptuous as vol @@ -137,7 +137,7 @@ def websocket_command( The schema must be either a dictionary where the keys are voluptuous markers, or a voluptuous.All schema where the first item is a voluptuous Mapping schema. """ - if isinstance(schema, dict): + if is_dict := isinstance(schema, dict): command = schema["type"] else: command = schema.validators[0].schema["type"] @@ -145,9 +145,13 @@ def websocket_command( def decorate(func: const.WebSocketCommandHandler) -> const.WebSocketCommandHandler: """Decorate ws command function.""" # pylint: disable=protected-access - if isinstance(schema, dict): + if is_dict and len(schema) == 1: # type only empty schema + func._ws_schema = False # type: ignore[attr-defined] + elif is_dict: func._ws_schema = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend(schema) # type: ignore[attr-defined] else: + if TYPE_CHECKING: + assert not isinstance(schema, dict) extended_schema = vol.All( schema.validators[0].extend( messages.BASE_COMMAND_MESSAGE_SCHEMA.schema