Add lookup caching to get_x_for_target (#157888)

This commit is contained in:
Abílio Costa
2025-12-16 12:17:58 +00:00
committed by GitHub
parent 9ba252d8e3
commit 7eecdc87fd
2 changed files with 190 additions and 28 deletions
@@ -33,6 +33,10 @@ FLATTENED_SERVICE_DESCRIPTIONS_CACHE: HassKey[
tuple[dict[str, dict[str, Any]], dict[str, dict[str, Any]]]
] = HassKey("websocket_automation_flat_service_description_cache")
AUTOMATION_COMPONENT_LOOKUP_CACHE: HassKey[
dict[str, tuple[Mapping[str, Any], _AutomationComponentLookupTable]]
] = HassKey("websocket_automation_component_lookup_cache")
@dataclass(slots=True, kw_only=True)
class _EntityFilter:
@@ -107,6 +111,14 @@ class _AutomationComponentLookupData:
)
@dataclass(slots=True, kw_only=True)
class _AutomationComponentLookupTable:
"""Helper class for looking up automation components."""
domain_components: dict[str | None, list[_AutomationComponentLookupData]]
component_count: int
def _get_automation_component_domains(
target_description: dict[str, Any],
) -> set[str | None]:
@@ -138,8 +150,51 @@ def _get_automation_component_domains(
return domains
def _get_automation_component_lookup_table(
hass: HomeAssistant,
component_type: str,
component_descriptions: Mapping[str, Mapping[str, Any] | None],
) -> _AutomationComponentLookupTable:
"""Get a dict of automation components keyed by domain, along with the total number of components.
Returns a cached object if available.
"""
try:
cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE]
except KeyError:
cache = hass.data[AUTOMATION_COMPONENT_LOOKUP_CACHE] = {}
if (cached := cache.get(component_type)) is not None:
cached_descriptions, cached_lookup = cached
if cached_descriptions is component_descriptions:
_LOGGER.debug(
"Using cached automation component lookup data for %s", component_type
)
return cached_lookup
lookup_table = _AutomationComponentLookupTable(
domain_components={}, component_count=0
)
for component, description in component_descriptions.items():
if description is None or CONF_TARGET not in description:
_LOGGER.debug("Skipping component %s without target description", component)
continue
domains = _get_automation_component_domains(description[CONF_TARGET])
lookup_data = _AutomationComponentLookupData.create(
component, description[CONF_TARGET]
)
for domain in domains:
lookup_table.domain_components.setdefault(domain, []).append(lookup_data)
lookup_table.component_count += 1
cache[component_type] = (component_descriptions, lookup_table)
return lookup_table
def _async_get_automation_components_for_target(
hass: HomeAssistant,
component_type: str,
target_selection: ConfigType,
expand_group: bool,
component_descriptions: Mapping[str, Mapping[str, Any] | None],
@@ -155,27 +210,17 @@ def _async_get_automation_components_for_target(
)
_LOGGER.debug("Extracted entities for lookup: %s", extracted)
# Build lookup structure: domain -> list of trigger/condition/service lookup data
domain_components: dict[str | None, list[_AutomationComponentLookupData]] = {}
component_count = 0
for component, description in component_descriptions.items():
if description is None or CONF_TARGET not in description:
_LOGGER.debug("Skipping component %s without target description", component)
continue
domains = _get_automation_component_domains(description[CONF_TARGET])
lookup_data = _AutomationComponentLookupData.create(
component, description[CONF_TARGET]
)
for domain in domains:
domain_components.setdefault(domain, []).append(lookup_data)
component_count += 1
_LOGGER.debug("Automation components per domain: %s", domain_components)
lookup_table = _get_automation_component_lookup_table(
hass, component_type, component_descriptions
)
_LOGGER.debug(
"Automation components per domain: %s", lookup_table.domain_components
)
entity_infos = entity_sources(hass)
matched_components: set[str] = set()
for entity_id in extracted.referenced | extracted.indirectly_referenced:
if component_count == len(matched_components):
if lookup_table.component_count == len(matched_components):
# All automation components matched already, so we don't need to iterate further
break
@@ -187,7 +232,11 @@ def _async_get_automation_components_for_target(
entity_domain = entity_id.split(".")[0]
entity_integration = entity_info["domain"]
for domain in (entity_domain, entity_integration, None):
for component_data in domain_components.get(domain, []):
if not (
domain_component_data := lookup_table.domain_components.get(domain)
):
continue
for component_data in domain_component_data:
if component_data.component in matched_components:
continue
if component_data.matches(
@@ -204,7 +253,7 @@ async def async_get_triggers_for_target(
"""Get triggers for a target."""
descriptions = await async_get_all_trigger_descriptions(hass)
return _async_get_automation_components_for_target(
hass, target_selector, expand_group, descriptions
hass, "triggers", target_selector, expand_group, descriptions
)
@@ -214,7 +263,7 @@ async def async_get_conditions_for_target(
"""Get conditions for a target."""
descriptions = await async_get_all_condition_descriptions(hass)
return _async_get_automation_components_for_target(
hass, target_selector, expand_group, descriptions
hass, "conditions", target_selector, expand_group, descriptions
)
@@ -247,5 +296,9 @@ async def async_get_services_for_target(
return flattened_descriptions
return _async_get_automation_components_for_target(
hass, target_selector, expand_group, get_flattened_service_descriptions()
hass,
"services",
target_selector,
expand_group,
get_flattened_service_descriptions(),
)