mirror of
https://github.com/home-assistant/core.git
synced 2026-06-11 11:41:42 +02:00
Compare commits
2 Commits
dev
...
fix-173069
| Author | SHA1 | Date | |
|---|---|---|---|
| adeae40ce1 | |||
| a6d3fb1808 |
@@ -6,7 +6,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import area_registry as ar
|
||||
from homeassistant.helpers import area_registry as ar, label_registry as lr
|
||||
|
||||
|
||||
@callback
|
||||
@@ -69,8 +69,9 @@ def websocket_create_area(
|
||||
data["aliases"] = {s_strip for s in data["aliases"] if (s_strip := s.strip())}
|
||||
|
||||
if "labels" in data:
|
||||
# Convert labels to a set
|
||||
data["labels"] = set(data["labels"])
|
||||
# Strip labels which are not in the label registry
|
||||
labels = set(data["labels"])
|
||||
data["labels"] = labels - lr.async_get_missing_label_ids(hass, labels)
|
||||
|
||||
try:
|
||||
entry = registry.async_create(**data)
|
||||
@@ -139,8 +140,11 @@ def websocket_update_area(
|
||||
data["aliases"] = {s_strip for s in data["aliases"] if (s_strip := s.strip())}
|
||||
|
||||
if "labels" in data:
|
||||
# Convert labels to a set
|
||||
data["labels"] = set(data["labels"])
|
||||
# Strip labels which are not in the label registry. This also cleans up
|
||||
# any stale labels already stored on the area (e.g. left behind by a
|
||||
# deleted label) the next time it is saved.
|
||||
labels = set(data["labels"])
|
||||
data["labels"] = labels - lr.async_get_missing_label_ids(hass, labels)
|
||||
|
||||
try:
|
||||
entry = registry.async_update(**data)
|
||||
|
||||
@@ -9,7 +9,7 @@ from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api import require_admin
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers import device_registry as dr, label_registry as lr
|
||||
from homeassistant.helpers.device_registry import DeviceEntry, DeviceEntryDisabler
|
||||
|
||||
|
||||
@@ -84,8 +84,11 @@ def websocket_update_device(
|
||||
msg["disabled_by"] = DeviceEntryDisabler(msg["disabled_by"])
|
||||
|
||||
if "labels" in msg:
|
||||
# Convert labels to a set
|
||||
msg["labels"] = set(msg["labels"])
|
||||
# Strip labels which are not in the label registry. This also cleans up
|
||||
# any stale labels already stored on the device (e.g. left behind by a
|
||||
# deleted label) the next time it is saved.
|
||||
labels = set(msg["labels"])
|
||||
msg["labels"] = labels - lr.async_get_missing_label_ids(hass, labels)
|
||||
|
||||
entry = cast(DeviceEntry, registry.async_update_device(**msg))
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
label_registry as lr,
|
||||
)
|
||||
from homeassistant.helpers.json import json_dumps
|
||||
|
||||
@@ -234,8 +235,11 @@ def websocket_update_entity(
|
||||
aliases.append(alias)
|
||||
|
||||
if "labels" in msg:
|
||||
# Convert labels to a set
|
||||
changes["labels"] = set(msg["labels"])
|
||||
# Strip labels which are not in the label registry. This also cleans up
|
||||
# any stale labels already stored on the entity (e.g. left behind by a
|
||||
# deleted label) the next time it is saved.
|
||||
labels = set(msg["labels"])
|
||||
changes["labels"] = labels - lr.async_get_missing_label_ids(hass, labels)
|
||||
|
||||
if "disabled_by" in msg and msg["disabled_by"] is None:
|
||||
# Don't allow enabling an entity of a disabled device
|
||||
|
||||
@@ -268,6 +268,17 @@ def async_get(hass: HomeAssistant) -> LabelRegistry:
|
||||
return LabelRegistry(hass)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_missing_label_ids(
|
||||
hass: HomeAssistant, label_ids: Iterable[str]
|
||||
) -> set[str]:
|
||||
"""Return the label ids which are missing from the label registry."""
|
||||
registry = async_get(hass)
|
||||
return {
|
||||
label_id for label_id in label_ids if registry.async_get_label(label_id) is None
|
||||
}
|
||||
|
||||
|
||||
async def async_load(hass: HomeAssistant, *, load_empty: bool = False) -> None:
|
||||
"""Load label registry."""
|
||||
assert DATA_REGISTRY not in hass.data
|
||||
|
||||
@@ -16,7 +16,7 @@ from homeassistant.const import (
|
||||
UnitOfTemperature,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import area_registry as ar
|
||||
from homeassistant.helpers import area_registry as ar, label_registry as lr
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from tests.common import ANY
|
||||
@@ -113,10 +113,13 @@ async def test_list_areas(
|
||||
async def test_create_area(
|
||||
client: MockHAClientWebSocket,
|
||||
area_registry: ar.AreaRegistry,
|
||||
label_registry: lr.LabelRegistry,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
mock_temperature_humidity_entity: None,
|
||||
) -> None:
|
||||
"""Test create entry."""
|
||||
label_registry.async_create("label_1")
|
||||
label_registry.async_create("label_2")
|
||||
# Create area with only mandatory parameters
|
||||
await client.send_json_auto_id(
|
||||
{"name": "mock", "type": "config/area_registry/create"}
|
||||
@@ -261,10 +264,13 @@ async def test_delete_non_existing_area(
|
||||
async def test_update_area(
|
||||
client: MockHAClientWebSocket,
|
||||
area_registry: ar.AreaRegistry,
|
||||
label_registry: lr.LabelRegistry,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
mock_temperature_humidity_entity: None,
|
||||
) -> None:
|
||||
"""Test update entry."""
|
||||
label_registry.async_create("label_1")
|
||||
label_registry.async_create("label_2")
|
||||
created_at = datetime.fromisoformat("2024-07-16T13:30:00.900075+00:00")
|
||||
freezer.move_to(created_at)
|
||||
area = area_registry.async_create("mock 1")
|
||||
@@ -372,6 +378,69 @@ async def test_update_area(
|
||||
assert len(area_registry.areas) == 1
|
||||
|
||||
|
||||
async def test_create_area_strips_unknown_labels(
|
||||
client: MockHAClientWebSocket,
|
||||
area_registry: ar.AreaRegistry,
|
||||
label_registry: lr.LabelRegistry,
|
||||
) -> None:
|
||||
"""Test labels not in the label registry are stripped when creating an area."""
|
||||
label_registry.async_create("label_1")
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "config/area_registry/create",
|
||||
"name": "mock",
|
||||
"labels": ["label_1", "missing"],
|
||||
}
|
||||
)
|
||||
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert msg["result"]["labels"] == ["label_1"]
|
||||
assert area_registry.async_get_area(msg["result"]["area_id"]).labels == {"label_1"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("labels", "expected_labels"),
|
||||
[
|
||||
pytest.param(["label_1", "missing"], {"label_1"}, id="strip_unknown"),
|
||||
pytest.param(["label_1", "stale_label"], {"label_1"}, id="strip_stale_resent"),
|
||||
pytest.param(["stale_label", "missing"], set(), id="strip_all_unknown"),
|
||||
pytest.param([], set(), id="remove_all"),
|
||||
],
|
||||
)
|
||||
async def test_update_area_strips_unknown_labels(
|
||||
client: MockHAClientWebSocket,
|
||||
area_registry: ar.AreaRegistry,
|
||||
label_registry: lr.LabelRegistry,
|
||||
labels: list[str],
|
||||
expected_labels: set[str],
|
||||
) -> None:
|
||||
"""Test labels not in the label registry are stripped on update.
|
||||
|
||||
A stale label already stored on the area is cleaned up when the area is
|
||||
next saved, even if the client sends it back.
|
||||
"""
|
||||
# Seed a stale label via the helper layer, bypassing WS stripping
|
||||
area = area_registry.async_create("mock", labels={"stale_label"})
|
||||
label_registry.async_create("label_1")
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "config/area_registry/update",
|
||||
"area_id": area.id,
|
||||
"labels": labels,
|
||||
}
|
||||
)
|
||||
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert set(msg["result"]["labels"]) == expected_labels
|
||||
assert area_registry.async_get_area(area.id).labels == expected_labels
|
||||
|
||||
|
||||
async def test_update_area_with_same_name(
|
||||
client: MockHAClientWebSocket, area_registry: ar.AreaRegistry
|
||||
) -> None:
|
||||
|
||||
@@ -9,7 +9,7 @@ from pytest_unordered import unordered
|
||||
from homeassistant.components.config import DOMAIN, device_registry
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
from homeassistant.helpers import device_registry as dr, label_registry as lr
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
@@ -216,9 +216,12 @@ async def test_update_device_labels(
|
||||
hass: HomeAssistant,
|
||||
client: MockHAClientWebSocket,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
label_registry: lr.LabelRegistry,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
) -> None:
|
||||
"""Test update entry labels."""
|
||||
label_registry.async_create("label1")
|
||||
label_registry.async_create("label2")
|
||||
entry = MockConfigEntry(title=None)
|
||||
entry.add_to_hass(hass)
|
||||
created_at = datetime.fromisoformat("2024-07-16T13:30:00.900075+00:00")
|
||||
@@ -262,6 +265,53 @@ async def test_update_device_labels(
|
||||
assert getattr(device, key) == value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("labels", "expected_labels"),
|
||||
[
|
||||
pytest.param(["label1", "missing"], {"label1"}, id="strip_unknown"),
|
||||
pytest.param(["label1", "stale_label"], {"label1"}, id="strip_stale_resent"),
|
||||
pytest.param(["stale_label", "missing"], set(), id="strip_all_unknown"),
|
||||
pytest.param([], set(), id="remove_all"),
|
||||
],
|
||||
)
|
||||
async def test_update_device_strips_unknown_labels(
|
||||
hass: HomeAssistant,
|
||||
client: MockHAClientWebSocket,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
label_registry: lr.LabelRegistry,
|
||||
labels: list[str],
|
||||
expected_labels: set[str],
|
||||
) -> None:
|
||||
"""Test labels not in the label registry are stripped on update.
|
||||
|
||||
A stale label already stored on the device is cleaned up when the device
|
||||
is next saved, even if the client sends it back.
|
||||
"""
|
||||
entry = MockConfigEntry(title=None)
|
||||
entry.add_to_hass(hass)
|
||||
device = device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
identifiers={("bridgeid", "0123")},
|
||||
)
|
||||
# Seed a stale label via the helper layer, bypassing WS stripping
|
||||
device_registry.async_update_device(device.id, labels={"stale_label"})
|
||||
label_registry.async_create("label1")
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "config/device_registry/update",
|
||||
"device_id": device.id,
|
||||
"labels": labels,
|
||||
}
|
||||
)
|
||||
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert set(msg["result"]["labels"]) == expected_labels
|
||||
assert device_registry.async_get(device.id).labels == expected_labels
|
||||
|
||||
|
||||
async def test_remove_config_entry_from_device(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
|
||||
@@ -10,7 +10,11 @@ from pytest_unordered import unordered
|
||||
from homeassistant.components.config import entity_registry
|
||||
from homeassistant.const import ATTR_ICON, EntityCategory
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er
|
||||
from homeassistant.helpers import (
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
label_registry as lr,
|
||||
)
|
||||
from homeassistant.helpers.device_registry import DeviceEntryDisabler
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.entity_registry import (
|
||||
@@ -607,9 +611,14 @@ async def test_get_entities(hass: HomeAssistant, client: MockHAClientWebSocket)
|
||||
|
||||
|
||||
async def test_update_entity(
|
||||
hass: HomeAssistant, client: MockHAClientWebSocket, freezer: FrozenDateTimeFactory
|
||||
hass: HomeAssistant,
|
||||
client: MockHAClientWebSocket,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
label_registry: lr.LabelRegistry,
|
||||
) -> None:
|
||||
"""Test updating entity."""
|
||||
label_registry.async_create("label1")
|
||||
label_registry.async_create("label2")
|
||||
created = datetime.fromisoformat("2024-02-14T12:00:00.900075+00:00")
|
||||
freezer.move_to(created)
|
||||
registry = mock_registry(
|
||||
@@ -999,6 +1008,55 @@ async def test_update_entity(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("labels", "expected_labels"),
|
||||
[
|
||||
pytest.param(["label1", "missing"], {"label1"}, id="strip_unknown"),
|
||||
pytest.param(["label1", "stale_label"], {"label1"}, id="strip_stale_resent"),
|
||||
pytest.param(["stale_label", "missing"], set(), id="strip_all_unknown"),
|
||||
pytest.param([], set(), id="remove_all"),
|
||||
],
|
||||
)
|
||||
async def test_update_entity_strips_unknown_labels(
|
||||
hass: HomeAssistant,
|
||||
client: MockHAClientWebSocket,
|
||||
label_registry: lr.LabelRegistry,
|
||||
labels: list[str],
|
||||
expected_labels: set[str],
|
||||
) -> None:
|
||||
"""Test labels not in the label registry are stripped on update.
|
||||
|
||||
A stale label already stored on the entity is cleaned up when the entity
|
||||
is next saved, even if the client sends it back.
|
||||
"""
|
||||
registry = mock_registry(
|
||||
hass,
|
||||
{
|
||||
"test_domain.world": RegistryEntryWithDefaults(
|
||||
entity_id="test_domain.world",
|
||||
unique_id="1234",
|
||||
platform="test_platform",
|
||||
labels={"stale_label"}, # not in the label registry
|
||||
)
|
||||
},
|
||||
)
|
||||
label_registry.async_create("label1")
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "config/entity_registry/update",
|
||||
"entity_id": "test_domain.world",
|
||||
"labels": labels,
|
||||
}
|
||||
)
|
||||
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg["success"]
|
||||
assert set(msg["result"]["entity_entry"]["labels"]) == expected_labels
|
||||
assert registry.entities["test_domain.world"].labels == expected_labels
|
||||
|
||||
|
||||
async def test_update_entity_require_restart(
|
||||
hass: HomeAssistant, client: MockHAClientWebSocket, freezer: FrozenDateTimeFactory
|
||||
) -> None:
|
||||
|
||||
@@ -558,3 +558,14 @@ async def test_migration_from_1_1(
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def test_async_get_missing_label_ids(
|
||||
hass: HomeAssistant, label_registry: lr.LabelRegistry
|
||||
) -> None:
|
||||
"""Test getting label ids missing from the registry."""
|
||||
label_registry.async_create("mock")
|
||||
|
||||
assert lr.async_get_missing_label_ids(hass, set()) == set()
|
||||
assert lr.async_get_missing_label_ids(hass, {"mock"}) == set()
|
||||
assert lr.async_get_missing_label_ids(hass, {"mock", "missing"}) == {"missing"}
|
||||
|
||||
Reference in New Issue
Block a user