Compare commits

...

2 Commits

Author SHA1 Message Date
Petar Petrov adeae40ce1 Strip unknown labels instead of erroring 2026-06-11 11:00:57 +03:00
Petar Petrov a6d3fb1808 Reject unknown label ids in registry websocket APIs 2026-06-10 14:01:04 +03:00
8 changed files with 224 additions and 14 deletions
@@ -6,7 +6,7 @@ import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import area_registry as ar
from homeassistant.helpers import area_registry as ar, label_registry as lr
@callback
@@ -69,8 +69,9 @@ def websocket_create_area(
data["aliases"] = {s_strip for s in data["aliases"] if (s_strip := s.strip())}
if "labels" in data:
# Convert labels to a set
data["labels"] = set(data["labels"])
# Strip labels which are not in the label registry
labels = set(data["labels"])
data["labels"] = labels - lr.async_get_missing_label_ids(hass, labels)
try:
entry = registry.async_create(**data)
@@ -139,8 +140,11 @@ def websocket_update_area(
data["aliases"] = {s_strip for s in data["aliases"] if (s_strip := s.strip())}
if "labels" in data:
# Convert labels to a set
data["labels"] = set(data["labels"])
# Strip labels which are not in the label registry. This also cleans up
# any stale labels already stored on the area (e.g. left behind by a
# deleted label) the next time it is saved.
labels = set(data["labels"])
data["labels"] = labels - lr.async_get_missing_label_ids(hass, labels)
try:
entry = registry.async_update(**data)
@@ -9,7 +9,7 @@ from homeassistant.components import websocket_api
from homeassistant.components.websocket_api import require_admin
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers import device_registry as dr, label_registry as lr
from homeassistant.helpers.device_registry import DeviceEntry, DeviceEntryDisabler
@@ -84,8 +84,11 @@ def websocket_update_device(
msg["disabled_by"] = DeviceEntryDisabler(msg["disabled_by"])
if "labels" in msg:
# Convert labels to a set
msg["labels"] = set(msg["labels"])
# Strip labels which are not in the label registry. This also cleans up
# any stale labels already stored on the device (e.g. left behind by a
# deleted label) the next time it is saved.
labels = set(msg["labels"])
msg["labels"] = labels - lr.async_get_missing_label_ids(hass, labels)
entry = cast(DeviceEntry, registry.async_update_device(**msg))
@@ -13,6 +13,7 @@ from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
entity_registry as er,
label_registry as lr,
)
from homeassistant.helpers.json import json_dumps
@@ -234,8 +235,11 @@ def websocket_update_entity(
aliases.append(alias)
if "labels" in msg:
# Convert labels to a set
changes["labels"] = set(msg["labels"])
# Strip labels which are not in the label registry. This also cleans up
# any stale labels already stored on the entity (e.g. left behind by a
# deleted label) the next time it is saved.
labels = set(msg["labels"])
changes["labels"] = labels - lr.async_get_missing_label_ids(hass, labels)
if "disabled_by" in msg and msg["disabled_by"] is None:
# Don't allow enabling an entity of a disabled device
+11
View File
@@ -268,6 +268,17 @@ def async_get(hass: HomeAssistant) -> LabelRegistry:
return LabelRegistry(hass)
@callback
def async_get_missing_label_ids(
hass: HomeAssistant, label_ids: Iterable[str]
) -> set[str]:
"""Return the label ids which are missing from the label registry."""
registry = async_get(hass)
return {
label_id for label_id in label_ids if registry.async_get_label(label_id) is None
}
async def async_load(hass: HomeAssistant, *, load_empty: bool = False) -> None:
"""Load label registry."""
assert DATA_REGISTRY not in hass.data
+70 -1
View File
@@ -16,7 +16,7 @@ from homeassistant.const import (
UnitOfTemperature,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import area_registry as ar
from homeassistant.helpers import area_registry as ar, label_registry as lr
from homeassistant.util.dt import utcnow
from tests.common import ANY
@@ -113,10 +113,13 @@ async def test_list_areas(
async def test_create_area(
client: MockHAClientWebSocket,
area_registry: ar.AreaRegistry,
label_registry: lr.LabelRegistry,
freezer: FrozenDateTimeFactory,
mock_temperature_humidity_entity: None,
) -> None:
"""Test create entry."""
label_registry.async_create("label_1")
label_registry.async_create("label_2")
# Create area with only mandatory parameters
await client.send_json_auto_id(
{"name": "mock", "type": "config/area_registry/create"}
@@ -261,10 +264,13 @@ async def test_delete_non_existing_area(
async def test_update_area(
client: MockHAClientWebSocket,
area_registry: ar.AreaRegistry,
label_registry: lr.LabelRegistry,
freezer: FrozenDateTimeFactory,
mock_temperature_humidity_entity: None,
) -> None:
"""Test update entry."""
label_registry.async_create("label_1")
label_registry.async_create("label_2")
created_at = datetime.fromisoformat("2024-07-16T13:30:00.900075+00:00")
freezer.move_to(created_at)
area = area_registry.async_create("mock 1")
@@ -372,6 +378,69 @@ async def test_update_area(
assert len(area_registry.areas) == 1
async def test_create_area_strips_unknown_labels(
client: MockHAClientWebSocket,
area_registry: ar.AreaRegistry,
label_registry: lr.LabelRegistry,
) -> None:
"""Test labels not in the label registry are stripped when creating an area."""
label_registry.async_create("label_1")
await client.send_json_auto_id(
{
"type": "config/area_registry/create",
"name": "mock",
"labels": ["label_1", "missing"],
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"]["labels"] == ["label_1"]
assert area_registry.async_get_area(msg["result"]["area_id"]).labels == {"label_1"}
@pytest.mark.parametrize(
("labels", "expected_labels"),
[
pytest.param(["label_1", "missing"], {"label_1"}, id="strip_unknown"),
pytest.param(["label_1", "stale_label"], {"label_1"}, id="strip_stale_resent"),
pytest.param(["stale_label", "missing"], set(), id="strip_all_unknown"),
pytest.param([], set(), id="remove_all"),
],
)
async def test_update_area_strips_unknown_labels(
client: MockHAClientWebSocket,
area_registry: ar.AreaRegistry,
label_registry: lr.LabelRegistry,
labels: list[str],
expected_labels: set[str],
) -> None:
"""Test labels not in the label registry are stripped on update.
A stale label already stored on the area is cleaned up when the area is
next saved, even if the client sends it back.
"""
# Seed a stale label via the helper layer, bypassing WS stripping
area = area_registry.async_create("mock", labels={"stale_label"})
label_registry.async_create("label_1")
await client.send_json_auto_id(
{
"type": "config/area_registry/update",
"area_id": area.id,
"labels": labels,
}
)
msg = await client.receive_json()
assert msg["success"]
assert set(msg["result"]["labels"]) == expected_labels
assert area_registry.async_get_area(area.id).labels == expected_labels
async def test_update_area_with_same_name(
client: MockHAClientWebSocket, area_registry: ar.AreaRegistry
) -> None:
@@ -9,7 +9,7 @@ from pytest_unordered import unordered
from homeassistant.components.config import DOMAIN, device_registry
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers import device_registry as dr, label_registry as lr
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
@@ -216,9 +216,12 @@ async def test_update_device_labels(
hass: HomeAssistant,
client: MockHAClientWebSocket,
device_registry: dr.DeviceRegistry,
label_registry: lr.LabelRegistry,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test update entry labels."""
label_registry.async_create("label1")
label_registry.async_create("label2")
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
created_at = datetime.fromisoformat("2024-07-16T13:30:00.900075+00:00")
@@ -262,6 +265,53 @@ async def test_update_device_labels(
assert getattr(device, key) == value
@pytest.mark.parametrize(
("labels", "expected_labels"),
[
pytest.param(["label1", "missing"], {"label1"}, id="strip_unknown"),
pytest.param(["label1", "stale_label"], {"label1"}, id="strip_stale_resent"),
pytest.param(["stale_label", "missing"], set(), id="strip_all_unknown"),
pytest.param([], set(), id="remove_all"),
],
)
async def test_update_device_strips_unknown_labels(
hass: HomeAssistant,
client: MockHAClientWebSocket,
device_registry: dr.DeviceRegistry,
label_registry: lr.LabelRegistry,
labels: list[str],
expected_labels: set[str],
) -> None:
"""Test labels not in the label registry are stripped on update.
A stale label already stored on the device is cleaned up when the device
is next saved, even if the client sends it back.
"""
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
identifiers={("bridgeid", "0123")},
)
# Seed a stale label via the helper layer, bypassing WS stripping
device_registry.async_update_device(device.id, labels={"stale_label"})
label_registry.async_create("label1")
await client.send_json_auto_id(
{
"type": "config/device_registry/update",
"device_id": device.id,
"labels": labels,
}
)
msg = await client.receive_json()
assert msg["success"]
assert set(msg["result"]["labels"]) == expected_labels
assert device_registry.async_get(device.id).labels == expected_labels
async def test_remove_config_entry_from_device(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
@@ -10,7 +10,11 @@ from pytest_unordered import unordered
from homeassistant.components.config import entity_registry
from homeassistant.const import ATTR_ICON, EntityCategory
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers import (
device_registry as dr,
entity_registry as er,
label_registry as lr,
)
from homeassistant.helpers.device_registry import DeviceEntryDisabler
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.entity_registry import (
@@ -607,9 +611,14 @@ async def test_get_entities(hass: HomeAssistant, client: MockHAClientWebSocket)
async def test_update_entity(
hass: HomeAssistant, client: MockHAClientWebSocket, freezer: FrozenDateTimeFactory
hass: HomeAssistant,
client: MockHAClientWebSocket,
freezer: FrozenDateTimeFactory,
label_registry: lr.LabelRegistry,
) -> None:
"""Test updating entity."""
label_registry.async_create("label1")
label_registry.async_create("label2")
created = datetime.fromisoformat("2024-02-14T12:00:00.900075+00:00")
freezer.move_to(created)
registry = mock_registry(
@@ -999,6 +1008,55 @@ async def test_update_entity(
}
@pytest.mark.parametrize(
("labels", "expected_labels"),
[
pytest.param(["label1", "missing"], {"label1"}, id="strip_unknown"),
pytest.param(["label1", "stale_label"], {"label1"}, id="strip_stale_resent"),
pytest.param(["stale_label", "missing"], set(), id="strip_all_unknown"),
pytest.param([], set(), id="remove_all"),
],
)
async def test_update_entity_strips_unknown_labels(
hass: HomeAssistant,
client: MockHAClientWebSocket,
label_registry: lr.LabelRegistry,
labels: list[str],
expected_labels: set[str],
) -> None:
"""Test labels not in the label registry are stripped on update.
A stale label already stored on the entity is cleaned up when the entity
is next saved, even if the client sends it back.
"""
registry = mock_registry(
hass,
{
"test_domain.world": RegistryEntryWithDefaults(
entity_id="test_domain.world",
unique_id="1234",
platform="test_platform",
labels={"stale_label"}, # not in the label registry
)
},
)
label_registry.async_create("label1")
await client.send_json_auto_id(
{
"type": "config/entity_registry/update",
"entity_id": "test_domain.world",
"labels": labels,
}
)
msg = await client.receive_json()
assert msg["success"]
assert set(msg["result"]["entity_entry"]["labels"]) == expected_labels
assert registry.entities["test_domain.world"].labels == expected_labels
async def test_update_entity_require_restart(
hass: HomeAssistant, client: MockHAClientWebSocket, freezer: FrozenDateTimeFactory
) -> None:
+11
View File
@@ -558,3 +558,14 @@ async def test_migration_from_1_1(
]
},
}
async def test_async_get_missing_label_ids(
hass: HomeAssistant, label_registry: lr.LabelRegistry
) -> None:
"""Test getting label ids missing from the registry."""
label_registry.async_create("mock")
assert lr.async_get_missing_label_ids(hass, set()) == set()
assert lr.async_get_missing_label_ids(hass, {"mock"}) == set()
assert lr.async_get_missing_label_ids(hass, {"mock", "missing"}) == {"missing"}