Simplify WS command entity/source (#99439)

This commit is contained in:
Erik Montnemery
2023-09-12 15:39:11 +02:00
committed by GitHub
parent e143bdf2f5
commit fabb098ec3
2 changed files with 26 additions and 120 deletions

View File

@ -11,7 +11,7 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant.auth.models import User
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.auth.permissions.const import POLICY_READ
from homeassistant.const import (
EVENT_STATE_CHANGED,
MATCH_ALL,
@ -52,7 +52,6 @@ from homeassistant.util.json import format_unserializable_data
from . import const, decorators, messages
from .connection import ActiveConnection
from .const import ERR_NOT_FOUND
from .messages import construct_event_message, construct_result_message
ALL_SERVICE_DESCRIPTIONS_JSON_CACHE = "websocket_api_all_service_descriptions_json"
@ -596,47 +595,35 @@ async def handle_render_template(
hass.loop.call_soon_threadsafe(info.async_refresh)
def _serialize_entity_sources(
entity_infos: dict[str, dict[str, str]]
) -> dict[str, Any]:
"""Prepare a websocket response from a dict of entity sources."""
result = {}
for entity_id, entity_info in entity_infos.items():
result[entity_id] = {"domain": entity_info["domain"]}
return result
@callback
@decorators.websocket_command(
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
)
@decorators.websocket_command({vol.Required("type"): "entity/source"})
def handle_entity_source(
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle entity source command."""
raw_sources = entity.entity_sources(hass)
all_entity_sources = entity.entity_sources(hass)
entity_perm = connection.user.permissions.check_entity
if "entity_id" not in msg:
if connection.user.permissions.access_all_entities(POLICY_READ):
sources = raw_sources
else:
sources = {
entity_id: source
for entity_id, source in raw_sources.items()
if entity_perm(entity_id, POLICY_READ)
}
if connection.user.permissions.access_all_entities(POLICY_READ):
entity_sources = all_entity_sources
else:
entity_sources = {
entity_id: source
for entity_id, source in all_entity_sources.items()
if entity_perm(entity_id, POLICY_READ)
}
connection.send_result(msg["id"], sources)
return
sources = {}
for entity_id in msg["entity_id"]:
if not entity_perm(entity_id, POLICY_READ):
raise Unauthorized(
context=connection.context(msg),
permission=POLICY_READ,
perm_category=CAT_ENTITIES,
)
if (source := raw_sources.get(entity_id)) is None:
connection.send_error(msg["id"], ERR_NOT_FOUND, "Entity not found")
return
sources[entity_id] = source
connection.send_result(msg["id"], sources)
connection.send_result(msg["id"], _serialize_entity_sources(entity_sources))
@decorators.websocket_command(