diff --git a/homeassistant/components/zwave_js/trigger.py b/homeassistant/components/zwave_js/trigger.py index e934faec70c..bae39e32eff 100644 --- a/homeassistant/components/zwave_js/trigger.py +++ b/homeassistant/components/zwave_js/trigger.py @@ -7,7 +7,7 @@ from homeassistant.helpers.trigger import Trigger from .triggers import event, value_updated -TRIGGERS = { +TRIGGERS: dict[str, type[Trigger]] = { event.PLATFORM_TYPE: event.EventTrigger, value_updated.PLATFORM_TYPE: value_updated.ValueUpdatedTrigger, } diff --git a/homeassistant/components/zwave_js/triggers/event.py b/homeassistant/components/zwave_js/triggers/event.py index 5cecf7096f2..79f034e48d9 100644 --- a/homeassistant/components/zwave_js/triggers/event.py +++ b/homeassistant/components/zwave_js/triggers/event.py @@ -131,131 +131,30 @@ async def async_validate_trigger_config( return config -async def async_attach_trigger( - hass: HomeAssistant, - config: ConfigType, - action: TriggerActionType, - trigger_info: TriggerInfo, - *, - platform_type: str = PLATFORM_TYPE, -) -> CALLBACK_TYPE: - """Listen for state changes based on configuration.""" - dev_reg = dr.async_get(hass) - if config[ATTR_EVENT_SOURCE] == "node" and not async_get_nodes_from_targets( - hass, config, dev_reg=dev_reg - ): - raise ValueError( - f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s." - ) - - event_source = config[ATTR_EVENT_SOURCE] - event_name = config[ATTR_EVENT] - event_data_filter = config.get(ATTR_EVENT_DATA, {}) - - unsubs: list[Callable] = [] - job = HassJob(action) - - trigger_data = trigger_info["trigger_data"] - - @callback - def async_on_event(event_data: dict, device: dr.DeviceEntry | None = None) -> None: - """Handle event.""" - for key, val in event_data_filter.items(): - if key not in event_data: - return - if ( - config[ATTR_PARTIAL_DICT_MATCH] - and isinstance(event_data[key], dict) - and isinstance(event_data_filter[key], dict) - ): - for key2, val2 in event_data_filter[key].items(): - if key2 not in event_data[key] or event_data[key][key2] != val2: - return - continue - if event_data[key] != val: - return - - payload = { - **trigger_data, - CONF_PLATFORM: platform_type, - ATTR_EVENT_SOURCE: event_source, - ATTR_EVENT: event_name, - ATTR_EVENT_DATA: event_data, - } - - primary_desc = f"Z-Wave JS '{event_source}' event '{event_name}' was emitted" - - if device: - device_name = device.name_by_user or device.name - payload[ATTR_DEVICE_ID] = device.id - home_and_node_id = get_home_and_node_id_from_device_entry(device) - assert home_and_node_id - payload[ATTR_NODE_ID] = home_and_node_id[1] - payload["description"] = f"{primary_desc} on {device_name}" - else: - payload["description"] = primary_desc - - payload["description"] = ( - f"{payload['description']} with event data: {event_data}" - ) - - hass.async_run_hass_job(job, {"trigger": payload}) - - @callback - def async_remove() -> None: - """Remove state listeners async.""" - for unsub in unsubs: - unsub() - unsubs.clear() - - @callback - def _create_zwave_listeners() -> None: - """Create Z-Wave JS listeners.""" - async_remove() - # Nodes list can come from different drivers and we will need to listen to - # server connections for all of them. - drivers: set[Driver] = set() - if not (nodes := async_get_nodes_from_targets(hass, config, dev_reg=dev_reg)): - entry_id = config[ATTR_CONFIG_ENTRY_ID] - entry = hass.config_entries.async_get_entry(entry_id) - assert entry - client: Client = entry.runtime_data[DATA_CLIENT] - driver = client.driver - assert driver - drivers.add(driver) - if event_source == "controller": - unsubs.append(driver.controller.on(event_name, async_on_event)) - else: - unsubs.append(driver.on(event_name, async_on_event)) - - for node in nodes: - driver = node.client.driver - assert driver is not None # The node comes from the driver. - drivers.add(driver) - device_identifier = get_device_id(driver, node) - device = dev_reg.async_get_device(identifiers={device_identifier}) - assert device - # We need to store the device for the callback - unsubs.append( - node.on(event_name, functools.partial(async_on_event, device=device)) - ) - unsubs.extend( - async_dispatcher_connect( - hass, - f"{DOMAIN}_{driver.controller.home_id}_connected_to_server", - _create_zwave_listeners, - ) - for driver in drivers - ) - - _create_zwave_listeners() - - return async_remove - - class EventTrigger(Trigger): """Z-Wave JS event trigger.""" + _platform_type = PLATFORM_TYPE + + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + action: TriggerActionType, + trigger_info: TriggerInfo, + ) -> None: + """Initialize trigger.""" + self._config = config + self._hass = hass + self._event_source = config[ATTR_EVENT_SOURCE] + self._event_name = config[ATTR_EVENT] + self._event_data_filter = config.get(ATTR_EVENT_DATA, {}) + + self._unsubs: list[Callable] = [] + self._job = HassJob(action) + + self._trigger_data = trigger_info["trigger_data"] + @classmethod async def async_validate_trigger_config( cls, hass: HomeAssistant, config: ConfigType @@ -272,4 +171,119 @@ class EventTrigger(Trigger): trigger_info: TriggerInfo, ) -> CALLBACK_TYPE: """Attach a trigger.""" - return await async_attach_trigger(hass, config, action, trigger_info) + dev_reg = dr.async_get(hass) + if config[ATTR_EVENT_SOURCE] == "node" and not async_get_nodes_from_targets( + hass, config, dev_reg=dev_reg + ): + raise ValueError( + f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s." + ) + + trigger = cls(hass, config, action, trigger_info) + trigger._create_zwave_listeners() + return trigger._async_remove + + @callback + def _async_on_event( + self, event_data: dict, device: dr.DeviceEntry | None = None + ) -> None: + """Handle event.""" + for key, val in self._event_data_filter.items(): + if key not in event_data: + return + if ( + self._config[ATTR_PARTIAL_DICT_MATCH] + and isinstance(event_data[key], dict) + and isinstance(self._event_data_filter[key], dict) + ): + for key2, val2 in self._event_data_filter[key].items(): + if key2 not in event_data[key] or event_data[key][key2] != val2: + return + continue + if event_data[key] != val: + return + + payload = { + **self._trigger_data, + CONF_PLATFORM: self._platform_type, + ATTR_EVENT_SOURCE: self._event_source, + ATTR_EVENT: self._event_name, + ATTR_EVENT_DATA: event_data, + } + + primary_desc = ( + f"Z-Wave JS '{self._event_source}' event '{self._event_name}' was emitted" + ) + + if device: + device_name = device.name_by_user or device.name + payload[ATTR_DEVICE_ID] = device.id + home_and_node_id = get_home_and_node_id_from_device_entry(device) + assert home_and_node_id + payload[ATTR_NODE_ID] = home_and_node_id[1] + payload["description"] = f"{primary_desc} on {device_name}" + else: + payload["description"] = primary_desc + + payload["description"] = ( + f"{payload['description']} with event data: {event_data}" + ) + + self._hass.async_run_hass_job(self._job, {"trigger": payload}) + + @callback + def _async_remove(self) -> None: + """Remove state listeners async.""" + for unsub in self._unsubs: + unsub() + self._unsubs.clear() + + @callback + def _create_zwave_listeners(self) -> None: + """Create Z-Wave JS listeners.""" + self._async_remove() + # Nodes list can come from different drivers and we will need to listen to + # server connections for all of them. + drivers: set[Driver] = set() + dev_reg = dr.async_get(self._hass) + if not ( + nodes := async_get_nodes_from_targets( + self._hass, self._config, dev_reg=dev_reg + ) + ): + entry_id = self._config[ATTR_CONFIG_ENTRY_ID] + entry = self._hass.config_entries.async_get_entry(entry_id) + assert entry + client: Client = entry.runtime_data[DATA_CLIENT] + driver = client.driver + assert driver + drivers.add(driver) + if self._event_source == "controller": + self._unsubs.append( + driver.controller.on(self._event_name, self._async_on_event) + ) + else: + self._unsubs.append(driver.on(self._event_name, self._async_on_event)) + + for node in nodes: + driver = node.client.driver + assert driver is not None # The node comes from the driver. + drivers.add(driver) + device_identifier = get_device_id(driver, node) + device = dev_reg.async_get_device(identifiers={device_identifier}) + assert device + # We need to store the device for the callback + self._unsubs.append( + node.on( + self._event_name, + functools.partial(self._async_on_event, device=device), + ) + ) + self._unsubs.extend( + async_dispatcher_connect( + self._hass, + f"{DOMAIN}_{driver.controller.home_id}_connected_to_server", + self._create_zwave_listeners, + ) + for driver in drivers + )