Refactor zwave_js event trigger

This commit is contained in:
Erik
2025-05-14 13:07:58 +02:00
parent a61e0577b1
commit d1f42b6eaa
2 changed files with 138 additions and 124 deletions

View File

@ -7,7 +7,7 @@ from homeassistant.helpers.trigger import Trigger
from .triggers import event, value_updated
TRIGGERS = {
TRIGGERS: dict[str, type[Trigger]] = {
event.PLATFORM_TYPE: event.EventTrigger,
value_updated.PLATFORM_TYPE: value_updated.ValueUpdatedTrigger,
}

View File

@ -131,131 +131,30 @@ async def async_validate_trigger_config(
return config
async def async_attach_trigger(
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
*,
platform_type: str = PLATFORM_TYPE,
) -> CALLBACK_TYPE:
"""Listen for state changes based on configuration."""
dev_reg = dr.async_get(hass)
if config[ATTR_EVENT_SOURCE] == "node" and not async_get_nodes_from_targets(
hass, config, dev_reg=dev_reg
):
raise ValueError(
f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s."
)
event_source = config[ATTR_EVENT_SOURCE]
event_name = config[ATTR_EVENT]
event_data_filter = config.get(ATTR_EVENT_DATA, {})
unsubs: list[Callable] = []
job = HassJob(action)
trigger_data = trigger_info["trigger_data"]
@callback
def async_on_event(event_data: dict, device: dr.DeviceEntry | None = None) -> None:
"""Handle event."""
for key, val in event_data_filter.items():
if key not in event_data:
return
if (
config[ATTR_PARTIAL_DICT_MATCH]
and isinstance(event_data[key], dict)
and isinstance(event_data_filter[key], dict)
):
for key2, val2 in event_data_filter[key].items():
if key2 not in event_data[key] or event_data[key][key2] != val2:
return
continue
if event_data[key] != val:
return
payload = {
**trigger_data,
CONF_PLATFORM: platform_type,
ATTR_EVENT_SOURCE: event_source,
ATTR_EVENT: event_name,
ATTR_EVENT_DATA: event_data,
}
primary_desc = f"Z-Wave JS '{event_source}' event '{event_name}' was emitted"
if device:
device_name = device.name_by_user or device.name
payload[ATTR_DEVICE_ID] = device.id
home_and_node_id = get_home_and_node_id_from_device_entry(device)
assert home_and_node_id
payload[ATTR_NODE_ID] = home_and_node_id[1]
payload["description"] = f"{primary_desc} on {device_name}"
else:
payload["description"] = primary_desc
payload["description"] = (
f"{payload['description']} with event data: {event_data}"
)
hass.async_run_hass_job(job, {"trigger": payload})
@callback
def async_remove() -> None:
"""Remove state listeners async."""
for unsub in unsubs:
unsub()
unsubs.clear()
@callback
def _create_zwave_listeners() -> None:
"""Create Z-Wave JS listeners."""
async_remove()
# Nodes list can come from different drivers and we will need to listen to
# server connections for all of them.
drivers: set[Driver] = set()
if not (nodes := async_get_nodes_from_targets(hass, config, dev_reg=dev_reg)):
entry_id = config[ATTR_CONFIG_ENTRY_ID]
entry = hass.config_entries.async_get_entry(entry_id)
assert entry
client: Client = entry.runtime_data[DATA_CLIENT]
driver = client.driver
assert driver
drivers.add(driver)
if event_source == "controller":
unsubs.append(driver.controller.on(event_name, async_on_event))
else:
unsubs.append(driver.on(event_name, async_on_event))
for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
drivers.add(driver)
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device(identifiers={device_identifier})
assert device
# We need to store the device for the callback
unsubs.append(
node.on(event_name, functools.partial(async_on_event, device=device))
)
unsubs.extend(
async_dispatcher_connect(
hass,
f"{DOMAIN}_{driver.controller.home_id}_connected_to_server",
_create_zwave_listeners,
)
for driver in drivers
)
_create_zwave_listeners()
return async_remove
class EventTrigger(Trigger):
"""Z-Wave JS event trigger."""
_platform_type = PLATFORM_TYPE
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
action: TriggerActionType,
trigger_info: TriggerInfo,
) -> None:
"""Initialize trigger."""
self._config = config
self._hass = hass
self._event_source = config[ATTR_EVENT_SOURCE]
self._event_name = config[ATTR_EVENT]
self._event_data_filter = config.get(ATTR_EVENT_DATA, {})
self._unsubs: list[Callable] = []
self._job = HassJob(action)
self._trigger_data = trigger_info["trigger_data"]
@classmethod
async def async_validate_trigger_config(
cls, hass: HomeAssistant, config: ConfigType
@ -272,4 +171,119 @@ class EventTrigger(Trigger):
trigger_info: TriggerInfo,
) -> CALLBACK_TYPE:
"""Attach a trigger."""
return await async_attach_trigger(hass, config, action, trigger_info)
dev_reg = dr.async_get(hass)
if config[ATTR_EVENT_SOURCE] == "node" and not async_get_nodes_from_targets(
hass, config, dev_reg=dev_reg
):
raise ValueError(
f"No nodes found for given {ATTR_DEVICE_ID}s or {ATTR_ENTITY_ID}s."
)
trigger = cls(hass, config, action, trigger_info)
trigger._create_zwave_listeners()
return trigger._async_remove
@callback
def _async_on_event(
self, event_data: dict, device: dr.DeviceEntry | None = None
) -> None:
"""Handle event."""
for key, val in self._event_data_filter.items():
if key not in event_data:
return
if (
self._config[ATTR_PARTIAL_DICT_MATCH]
and isinstance(event_data[key], dict)
and isinstance(self._event_data_filter[key], dict)
):
for key2, val2 in self._event_data_filter[key].items():
if key2 not in event_data[key] or event_data[key][key2] != val2:
return
continue
if event_data[key] != val:
return
payload = {
**self._trigger_data,
CONF_PLATFORM: self._platform_type,
ATTR_EVENT_SOURCE: self._event_source,
ATTR_EVENT: self._event_name,
ATTR_EVENT_DATA: event_data,
}
primary_desc = (
f"Z-Wave JS '{self._event_source}' event '{self._event_name}' was emitted"
)
if device:
device_name = device.name_by_user or device.name
payload[ATTR_DEVICE_ID] = device.id
home_and_node_id = get_home_and_node_id_from_device_entry(device)
assert home_and_node_id
payload[ATTR_NODE_ID] = home_and_node_id[1]
payload["description"] = f"{primary_desc} on {device_name}"
else:
payload["description"] = primary_desc
payload["description"] = (
f"{payload['description']} with event data: {event_data}"
)
self._hass.async_run_hass_job(self._job, {"trigger": payload})
@callback
def _async_remove(self) -> None:
"""Remove state listeners async."""
for unsub in self._unsubs:
unsub()
self._unsubs.clear()
@callback
def _create_zwave_listeners(self) -> None:
"""Create Z-Wave JS listeners."""
self._async_remove()
# Nodes list can come from different drivers and we will need to listen to
# server connections for all of them.
drivers: set[Driver] = set()
dev_reg = dr.async_get(self._hass)
if not (
nodes := async_get_nodes_from_targets(
self._hass, self._config, dev_reg=dev_reg
)
):
entry_id = self._config[ATTR_CONFIG_ENTRY_ID]
entry = self._hass.config_entries.async_get_entry(entry_id)
assert entry
client: Client = entry.runtime_data[DATA_CLIENT]
driver = client.driver
assert driver
drivers.add(driver)
if self._event_source == "controller":
self._unsubs.append(
driver.controller.on(self._event_name, self._async_on_event)
)
else:
self._unsubs.append(driver.on(self._event_name, self._async_on_event))
for node in nodes:
driver = node.client.driver
assert driver is not None # The node comes from the driver.
drivers.add(driver)
device_identifier = get_device_id(driver, node)
device = dev_reg.async_get_device(identifiers={device_identifier})
assert device
# We need to store the device for the callback
self._unsubs.append(
node.on(
self._event_name,
functools.partial(self._async_on_event, device=device),
)
)
self._unsubs.extend(
async_dispatcher_connect(
self._hass,
f"{DOMAIN}_{driver.controller.home_id}_connected_to_server",
self._create_zwave_listeners,
)
for driver in drivers
)