From 440212ddcee21ed82a81242f7aa8556fd2b217d1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 5 Feb 2024 16:23:53 -0600 Subject: [PATCH] Reduce dict lookups in entity registry indices (#109712) --- homeassistant/helpers/entity_registry.py | 39 +++++++++++------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index dc0c29dc0b7..f72aece4c70 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -449,9 +449,9 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): super().__init__() self._entry_ids: dict[str, RegistryEntry] = {} self._index: dict[tuple[str, str, str], str] = {} - self._config_entry_id_index: dict[str, list[str]] = {} - self._device_id_index: dict[str, list[str]] = {} - self._area_id_index: dict[str, list[str]] = {} + self._config_entry_id_index: dict[str, list[RegistryEntry]] = {} + self._device_id_index: dict[str, list[RegistryEntry]] = {} + self._area_id_index: dict[str, list[RegistryEntry]] = {} def values(self) -> ValuesView[RegistryEntry]: """Return the underlying values to avoid __iter__ overhead.""" @@ -466,23 +466,23 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): self._entry_ids[entry.id] = entry self._index[(entry.domain, entry.platform, entry.unique_id)] = entry.entity_id if (config_entry_id := entry.config_entry_id) is not None: - self._config_entry_id_index.setdefault(config_entry_id, []).append(key) + self._config_entry_id_index.setdefault(config_entry_id, []).append(entry) if (device_id := entry.device_id) is not None: - self._device_id_index.setdefault(device_id, []).append(key) + self._device_id_index.setdefault(device_id, []).append(entry) if (area_id := entry.area_id) is not None: - self._area_id_index.setdefault(area_id, []).append(key) + self._area_id_index.setdefault(area_id, []).append(entry) def _unindex_entry_value( - self, key: str, value: str, index: dict[str, list[str]] + self, entry: RegistryEntry, value: str, index: dict[str, list[RegistryEntry]] ) -> None: """Unindex an entry value. - key is the entry key + entry is the entry value is the value to unindex such as config_entry_id or device_id. index is the index to unindex from. """ entries = index[value] - entries.remove(key) + entries.remove(entry) if not entries: del index[value] @@ -492,11 +492,13 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): del self._entry_ids[entry.id] del self._index[(entry.domain, entry.platform, entry.unique_id)] if config_entry_id := entry.config_entry_id: - self._unindex_entry_value(key, config_entry_id, self._config_entry_id_index) + self._unindex_entry_value( + entry, config_entry_id, self._config_entry_id_index + ) if device_id := entry.device_id: - self._unindex_entry_value(key, device_id, self._device_id_index) + self._unindex_entry_value(entry, device_id, self._device_id_index) if area_id := entry.area_id: - self._unindex_entry_value(key, area_id, self._area_id_index) + self._unindex_entry_value(entry, area_id, self._area_id_index) def __delitem__(self, key: str) -> None: """Remove an item.""" @@ -515,26 +517,21 @@ class EntityRegistryItems(UserDict[str, RegistryEntry]): self, device_id: str, include_disabled_entities: bool = False ) -> list[RegistryEntry]: """Get entries for device.""" - data = self.data return [ entry - for key in self._device_id_index.get(device_id, ()) - if not (entry := data[key]).disabled_by or include_disabled_entities + for entry in self._device_id_index.get(device_id, ()) + if not entry.disabled_by or include_disabled_entities ] def get_entries_for_config_entry_id( self, config_entry_id: str ) -> list[RegistryEntry]: """Get entries for config entry.""" - data = self.data - return [ - data[key] for key in self._config_entry_id_index.get(config_entry_id, ()) - ] + return list(self._config_entry_id_index.get(config_entry_id, ())) def get_entries_for_area_id(self, area_id: str) -> list[RegistryEntry]: """Get entries for area.""" - data = self.data - return [data[key] for key in self._area_id_index.get(area_id, ())] + return list(self._area_id_index.get(area_id, ())) class EntityRegistry: