diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index f983e230897..f5bbc2b6e60 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -176,6 +176,37 @@ def devices_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: return list(automation_entity.referenced_devices) +@callback +def automations_with_area(hass: HomeAssistant, area_id: str) -> list[str]: + """Return all automations that reference the area.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + return [ + automation_entity.entity_id + for automation_entity in component.entities + if area_id in automation_entity.referenced_areas + ] + + +@callback +def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]: + """Return all areas in an automation.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + automation_entity = component.get_entity(entity_id) + + if automation_entity is None: + return [] + + return list(automation_entity.referenced_areas) + + async def async_setup(hass, config): """Set up all automations.""" # Local import to avoid circular import @@ -293,6 +324,11 @@ class AutomationEntity(ToggleEntity, RestoreEntity): """Return True if entity is on.""" return self._async_detach_triggers is not None or self._is_enabled + @property + def referenced_areas(self): + """Return a set of referenced areas.""" + return self.action_script.referenced_areas + @property def referenced_devices(self): """Return a set of referenced devices.""" diff --git a/homeassistant/components/script/__init__.py b/homeassistant/components/script/__init__.py index d8af3e4a96d..320033e1013 100644 --- a/homeassistant/components/script/__init__.py +++ b/homeassistant/components/script/__init__.py @@ -165,6 +165,37 @@ def devices_in_script(hass: HomeAssistant, entity_id: str) -> list[str]: return list(script_entity.script.referenced_devices) +@callback +def scripts_with_area(hass: HomeAssistant, area_id: str) -> list[str]: + """Return all scripts that reference the area.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + return [ + script_entity.entity_id + for script_entity in component.entities + if area_id in script_entity.script.referenced_areas + ] + + +@callback +def areas_in_script(hass: HomeAssistant, entity_id: str) -> list[str]: + """Return all areas in a script.""" + if DOMAIN not in hass.data: + return [] + + component = hass.data[DOMAIN] + + script_entity = component.get_entity(entity_id) + + if script_entity is None: + return [] + + return list(script_entity.script.referenced_areas) + + async def async_setup(hass, config): """Load the scripts from the configuration.""" hass.data[DOMAIN] = component = EntityComponent(_LOGGER, DOMAIN, hass) diff --git a/homeassistant/components/search/__init__.py b/homeassistant/components/search/__init__.py index 81e33aa24b5..291ef0b52e2 100644 --- a/homeassistant/components/search/__init__.py +++ b/homeassistant/components/search/__init__.py @@ -38,12 +38,12 @@ async def async_setup(hass: HomeAssistant, config: dict): vol.Required("item_id"): str, } ) -async def websocket_search_related(hass, connection, msg): +def websocket_search_related(hass, connection, msg): """Handle search.""" searcher = Searcher( hass, - await device_registry.async_get_registry(hass), - await entity_registry.async_get_registry(hass), + device_registry.async_get(hass), + entity_registry.async_get(hass), ) connection.send_result( msg["id"], searcher.async_search(msg["item_type"], msg["item_id"]) @@ -127,6 +127,12 @@ class Searcher: ): self._add_or_resolve("entity", entity_entry.entity_id) + for entity_id in script.scripts_with_area(self.hass, area_id): + self._add_or_resolve("entity", entity_id) + + for entity_id in automation.automations_with_area(self.hass, area_id): + self._add_or_resolve("entity", entity_id) + @callback def _resolve_device(self, device_id) -> None: """Resolve a device.""" @@ -198,6 +204,9 @@ class Searcher: for device in automation.devices_in_automation(self.hass, automation_entity_id): self._add_or_resolve("device", device) + for area in automation.areas_in_automation(self.hass, automation_entity_id): + self._add_or_resolve("area", area) + @callback def _resolve_script(self, script_entity_id) -> None: """Resolve a script. @@ -210,6 +219,9 @@ class Searcher: for device in script.devices_in_script(self.hass, script_entity_id): self._add_or_resolve("device", device) + for area in script.areas_in_script(self.hass, script_entity_id): + self._add_or_resolve("area", area) + @callback def _resolve_group(self, group_entity_id) -> None: """Resolve a group. diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 7396925db39..e342f0ff9a8 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -17,6 +17,7 @@ from homeassistant import exceptions from homeassistant.components import device_automation, scene from homeassistant.components.logger import LOGSEVERITY from homeassistant.const import ( + ATTR_AREA_ID, ATTR_DEVICE_ID, ATTR_ENTITY_ID, CONF_ALIAS, @@ -900,10 +901,10 @@ def _referenced_extract_ids(data: dict[str, Any], key: str, found: set[str]) -> return if isinstance(item_ids, str): - item_ids = [item_ids] - - for item_id in item_ids: - found.add(item_id) + found.add(item_ids) + else: + for item_id in item_ids: + found.add(item_id) class Script: @@ -970,6 +971,7 @@ class Script: self._choose_data: dict[int, dict[str, Any]] = {} self._referenced_entities: set[str] | None = None self._referenced_devices: set[str] | None = None + self._referenced_areas: set[str] | None = None self.variables = variables self._variables_dynamic = template.is_complex(variables) if self._variables_dynamic: @@ -1031,6 +1033,28 @@ class Script: """Return true if the current mode support max.""" return self.script_mode in (SCRIPT_MODE_PARALLEL, SCRIPT_MODE_QUEUED) + @property + def referenced_areas(self): + """Return a set of referenced areas.""" + if self._referenced_areas is not None: + return self._referenced_areas + + referenced: set[str] = set() + + for step in self.sequence: + action = cv.determine_script_action(step) + + if action == cv.SCRIPT_ACTION_CALL_SERVICE: + for data in ( + step.get(CONF_TARGET), + step.get(service.CONF_SERVICE_DATA), + step.get(service.CONF_SERVICE_DATA_TEMPLATE), + ): + _referenced_extract_ids(data, ATTR_AREA_ID, referenced) + + self._referenced_areas = referenced + return referenced + @property def referenced_devices(self): """Return a set of referenced devices.""" @@ -1044,7 +1068,6 @@ class Script: if action == cv.SCRIPT_ACTION_CALL_SERVICE: for data in ( - step, step.get(CONF_TARGET), step.get(service.CONF_SERVICE_DATA), step.get(service.CONF_SERVICE_DATA_TEMPLATE), diff --git a/tests/components/search/test_init.py b/tests/components/search/test_init.py index 57d2c365e71..82935f2b41f 100644 --- a/tests/components/search/test_init.py +++ b/tests/components/search/test_init.py @@ -193,6 +193,10 @@ async def test_search(hass): }, ) + # Ensure automations set up correctly. + assert hass.states.get("automation.wled_entity") is not None + assert hass.states.get("automation.wled_device") is not None + # Explore the graph from every node and make sure we find the same results expected = { "config_entry": {wled_config_entry.entry_id}, @@ -276,6 +280,64 @@ async def test_search(hass): assert searcher.async_search(search_type, search_id) == {} +async def test_area_lookup(hass): + """Test area based lookup.""" + area_reg = ar.async_get(hass) + device_reg = dr.async_get(hass) + entity_reg = er.async_get(hass) + + living_room_area = area_reg.async_create("Living Room") + + await async_setup_component( + hass, + "script", + { + "script": { + "wled": { + "sequence": [ + { + "service": "light.turn_on", + "target": {"area_id": living_room_area.id}, + }, + ] + }, + } + }, + ) + + assert await async_setup_component( + hass, + "automation", + { + "automation": [ + { + "alias": "area_turn_on", + "trigger": {"platform": "template", "value_template": "true"}, + "action": [ + { + "service": "light.turn_on", + "data": { + "area_id": living_room_area.id, + }, + }, + ], + }, + ] + }, + ) + + searcher = search.Searcher(hass, device_reg, entity_reg) + assert searcher.async_search("area", living_room_area.id) == { + "script": {"script.wled"}, + "automation": {"automation.area_turn_on"}, + } + + searcher = search.Searcher(hass, device_reg, entity_reg) + assert searcher.async_search("automation", "automation.area_turn_on") == { + "area": {living_room_area.id}, + } + + async def test_ws_api(hass, hass_ws_client): """Test WS API.""" assert await async_setup_component(hass, "search", {})