mirror of
https://github.com/home-assistant/core.git
synced 2025-08-05 05:35:11 +02:00
Use relative trigger keys (#149846)
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
}
|
||||
},
|
||||
"triggers": {
|
||||
"mqtt": {
|
||||
"_": {
|
||||
"trigger": "mdi:swap-horizontal"
|
||||
}
|
||||
}
|
||||
|
@@ -1285,7 +1285,7 @@
|
||||
}
|
||||
},
|
||||
"triggers": {
|
||||
"mqtt": {
|
||||
"_": {
|
||||
"name": "MQTT",
|
||||
"description": "When a specific message is received on a given MQTT topic.",
|
||||
"description_configured": "When an MQTT message has been received",
|
||||
|
@@ -1,6 +1,6 @@
|
||||
# Describes the format for MQTT triggers
|
||||
|
||||
mqtt:
|
||||
_:
|
||||
fields:
|
||||
payload:
|
||||
example: "on"
|
||||
|
@@ -8,8 +8,8 @@ from homeassistant.helpers.trigger import Trigger
|
||||
from .triggers import event, value_updated
|
||||
|
||||
TRIGGERS = {
|
||||
event.PLATFORM_TYPE: event.EventTrigger,
|
||||
value_updated.PLATFORM_TYPE: value_updated.ValueUpdatedTrigger,
|
||||
event.RELATIVE_PLATFORM_TYPE: event.EventTrigger,
|
||||
value_updated.RELATIVE_PLATFORM_TYPE: value_updated.ValueUpdatedTrigger,
|
||||
}
|
||||
|
||||
|
||||
|
@@ -34,8 +34,11 @@ from ..helpers import (
|
||||
)
|
||||
from .trigger_helpers import async_bypass_dynamic_config_validation
|
||||
|
||||
# Relative platform type should be <SUBMODULE_NAME>
|
||||
RELATIVE_PLATFORM_TYPE = f"{__name__.rsplit('.', maxsplit=1)[-1]}"
|
||||
|
||||
# Platform type should be <DOMAIN>.<SUBMODULE_NAME>
|
||||
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}"
|
||||
PLATFORM_TYPE = f"{DOMAIN}.{RELATIVE_PLATFORM_TYPE}"
|
||||
|
||||
|
||||
def validate_non_node_event_source(obj: dict) -> dict:
|
||||
|
@@ -37,8 +37,11 @@ from ..const import (
|
||||
from ..helpers import async_get_nodes_from_targets, get_device_id
|
||||
from .trigger_helpers import async_bypass_dynamic_config_validation
|
||||
|
||||
# Relative platform type should be <SUBMODULE_NAME>
|
||||
RELATIVE_PLATFORM_TYPE = f"{__name__.rsplit('.', maxsplit=1)[-1]}"
|
||||
|
||||
# Platform type should be <DOMAIN>.<SUBMODULE_NAME>
|
||||
PLATFORM_TYPE = f"{DOMAIN}.{__name__.rsplit('.', maxsplit=1)[-1]}"
|
||||
PLATFORM_TYPE = f"{DOMAIN}.{RELATIVE_PLATFORM_TYPE}"
|
||||
|
||||
ATTR_FROM = "from"
|
||||
ATTR_TO = "to"
|
||||
|
21
homeassistant/helpers/automation.py
Normal file
21
homeassistant/helpers/automation.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Helpers for automation."""
|
||||
|
||||
|
||||
def get_absolute_description_key(domain: str, key: str) -> str:
|
||||
"""Return the absolute description key."""
|
||||
if not key.startswith("_"):
|
||||
return f"{domain}.{key}"
|
||||
key = key[1:] # Remove leading underscore
|
||||
if not key:
|
||||
return domain
|
||||
return key
|
||||
|
||||
|
||||
def get_relative_description_key(domain: str, key: str) -> str:
|
||||
"""Return the relative description key."""
|
||||
platform, *subtype = key.split(".", 1)
|
||||
if platform != domain:
|
||||
return f"_{key}"
|
||||
if not subtype:
|
||||
return "_"
|
||||
return subtype[0]
|
@@ -644,6 +644,13 @@ def slug(value: Any) -> str:
|
||||
raise vol.Invalid(f"invalid slug {value} (try {slg})")
|
||||
|
||||
|
||||
def underscore_slug(value: Any) -> str:
|
||||
"""Validate value is a valid slug, possibly starting with an underscore."""
|
||||
if value.startswith("_"):
|
||||
return f"_{slug(value[1:])}"
|
||||
return slug(value)
|
||||
|
||||
|
||||
def schema_with_slug_keys(
|
||||
value_schema: dict | Callable, *, slug_validator: Callable[[Any], str] = slug
|
||||
) -> Callable:
|
||||
|
@@ -40,9 +40,9 @@ from homeassistant.loader import (
|
||||
from homeassistant.util.async_ import create_eager_task
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.yaml import load_yaml_dict
|
||||
from homeassistant.util.yaml.loader import JSON_TYPE
|
||||
|
||||
from . import config_validation as cv, selector
|
||||
from .automation import get_absolute_description_key, get_relative_description_key
|
||||
from .integration_platform import async_process_integration_platforms
|
||||
from .selector import TargetSelector
|
||||
from .template import Template
|
||||
@@ -100,7 +100,7 @@ def starts_with_dot(key: str) -> str:
|
||||
_TRIGGERS_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Remove(vol.All(str, starts_with_dot)): object,
|
||||
cv.slug: vol.Any(None, _TRIGGER_SCHEMA),
|
||||
cv.underscore_slug: vol.Any(None, _TRIGGER_SCHEMA),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -139,6 +139,7 @@ async def _register_trigger_platform(
|
||||
|
||||
if hasattr(platform, "async_get_triggers"):
|
||||
for trigger_key in await platform.async_get_triggers(hass):
|
||||
trigger_key = get_absolute_description_key(integration_domain, trigger_key)
|
||||
hass.data[TRIGGERS][trigger_key] = integration_domain
|
||||
new_triggers.add(trigger_key)
|
||||
elif hasattr(platform, "async_validate_trigger_config") or hasattr(
|
||||
@@ -357,9 +358,8 @@ class PluggableAction:
|
||||
|
||||
|
||||
async def _async_get_trigger_platform(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> TriggerProtocol:
|
||||
trigger_key: str = config[CONF_PLATFORM]
|
||||
hass: HomeAssistant, trigger_key: str
|
||||
) -> tuple[str, TriggerProtocol]:
|
||||
platform_and_sub_type = trigger_key.split(".")
|
||||
platform = platform_and_sub_type[0]
|
||||
platform = _PLATFORM_ALIASES.get(platform, platform)
|
||||
@@ -368,7 +368,7 @@ async def _async_get_trigger_platform(
|
||||
except IntegrationNotFound:
|
||||
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified") from None
|
||||
try:
|
||||
return await integration.async_get_platform("trigger")
|
||||
return platform, await integration.async_get_platform("trigger")
|
||||
except ImportError:
|
||||
raise vol.Invalid(
|
||||
f"Integration '{platform}' does not provide trigger support"
|
||||
@@ -381,11 +381,14 @@ async def async_validate_trigger_config(
|
||||
"""Validate triggers."""
|
||||
config = []
|
||||
for conf in trigger_config:
|
||||
platform = await _async_get_trigger_platform(hass, conf)
|
||||
trigger_key: str = conf[CONF_PLATFORM]
|
||||
platform_domain, platform = await _async_get_trigger_platform(hass, trigger_key)
|
||||
if hasattr(platform, "async_get_triggers"):
|
||||
trigger_descriptors = await platform.async_get_triggers(hass)
|
||||
trigger_key: str = conf[CONF_PLATFORM]
|
||||
if not (trigger := trigger_descriptors.get(trigger_key)):
|
||||
relative_trigger_key = get_relative_description_key(
|
||||
platform_domain, trigger_key
|
||||
)
|
||||
if not (trigger := trigger_descriptors.get(relative_trigger_key)):
|
||||
raise vol.Invalid(f"Invalid trigger '{trigger_key}' specified")
|
||||
conf = await trigger.async_validate_trigger_config(hass, conf)
|
||||
elif hasattr(platform, "async_validate_trigger_config"):
|
||||
@@ -471,7 +474,8 @@ async def async_initialize_triggers(
|
||||
if not enabled:
|
||||
continue
|
||||
|
||||
platform = await _async_get_trigger_platform(hass, conf)
|
||||
trigger_key: str = conf[CONF_PLATFORM]
|
||||
platform_domain, platform = await _async_get_trigger_platform(hass, trigger_key)
|
||||
trigger_id = conf.get(CONF_ID, f"{idx}")
|
||||
trigger_idx = f"{idx}"
|
||||
trigger_alias = conf.get(CONF_ALIAS)
|
||||
@@ -487,7 +491,10 @@ async def async_initialize_triggers(
|
||||
action_wrapper = _trigger_action_wrapper(hass, action, conf)
|
||||
if hasattr(platform, "async_get_triggers"):
|
||||
trigger_descriptors = await platform.async_get_triggers(hass)
|
||||
trigger = trigger_descriptors[conf[CONF_PLATFORM]](hass, conf)
|
||||
relative_trigger_key = get_relative_description_key(
|
||||
platform_domain, trigger_key
|
||||
)
|
||||
trigger = trigger_descriptors[relative_trigger_key](hass, conf)
|
||||
coro = trigger.async_attach_trigger(action_wrapper, info)
|
||||
else:
|
||||
coro = platform.async_attach_trigger(hass, conf, action_wrapper, info)
|
||||
@@ -525,11 +532,11 @@ async def async_initialize_triggers(
|
||||
return remove_triggers
|
||||
|
||||
|
||||
def _load_triggers_file(hass: HomeAssistant, integration: Integration) -> JSON_TYPE:
|
||||
def _load_triggers_file(integration: Integration) -> dict[str, Any]:
|
||||
"""Load triggers file for an integration."""
|
||||
try:
|
||||
return cast(
|
||||
JSON_TYPE,
|
||||
dict[str, Any],
|
||||
_TRIGGERS_SCHEMA(
|
||||
load_yaml_dict(str(integration.file_path / "triggers.yaml"))
|
||||
),
|
||||
@@ -549,11 +556,14 @@ def _load_triggers_file(hass: HomeAssistant, integration: Integration) -> JSON_T
|
||||
|
||||
|
||||
def _load_triggers_files(
|
||||
hass: HomeAssistant, integrations: Iterable[Integration]
|
||||
) -> dict[str, JSON_TYPE]:
|
||||
integrations: Iterable[Integration],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Load trigger files for multiple integrations."""
|
||||
return {
|
||||
integration.domain: _load_triggers_file(hass, integration)
|
||||
integration.domain: {
|
||||
get_absolute_description_key(integration.domain, key): value
|
||||
for key, value in _load_triggers_file(integration).items()
|
||||
}
|
||||
for integration in integrations
|
||||
}
|
||||
|
||||
@@ -574,7 +584,7 @@ async def async_get_all_descriptions(
|
||||
return descriptions_cache
|
||||
|
||||
# Files we loaded for missing descriptions
|
||||
new_triggers_descriptions: dict[str, JSON_TYPE] = {}
|
||||
new_triggers_descriptions: dict[str, dict[str, Any]] = {}
|
||||
# We try to avoid making a copy in the event the cache is good,
|
||||
# but now we must make a copy in case new triggers get added
|
||||
# while we are loading the missing ones so we do not
|
||||
@@ -601,7 +611,7 @@ async def async_get_all_descriptions(
|
||||
|
||||
if integrations:
|
||||
new_triggers_descriptions = await hass.async_add_executor_job(
|
||||
_load_triggers_files, hass, integrations
|
||||
_load_triggers_files, integrations
|
||||
)
|
||||
|
||||
# Make a copy of the old cache and add missing descriptions to it
|
||||
@@ -610,7 +620,7 @@ async def async_get_all_descriptions(
|
||||
domain = triggers[missing_trigger]
|
||||
|
||||
if (
|
||||
yaml_description := new_triggers_descriptions.get(domain, {}).get( # type: ignore[union-attr]
|
||||
yaml_description := new_triggers_descriptions.get(domain, {}).get(
|
||||
missing_trigger
|
||||
)
|
||||
) is None:
|
||||
|
@@ -136,7 +136,7 @@ TRIGGER_ICONS_SCHEMA = cv.schema_with_slug_keys(
|
||||
vol.Optional("trigger"): icon_value_validator,
|
||||
}
|
||||
),
|
||||
slug_validator=translation_key_validator,
|
||||
slug_validator=cv.underscore_slug,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -450,7 +450,7 @@ def gen_strings_schema(config: Config, integration: Integration) -> vol.Schema:
|
||||
slug_validator=translation_key_validator,
|
||||
),
|
||||
},
|
||||
slug_validator=translation_key_validator,
|
||||
slug_validator=cv.underscore_slug,
|
||||
),
|
||||
vol.Optional("conversation"): {
|
||||
vol.Required("agent"): {
|
||||
|
@@ -50,7 +50,7 @@ TRIGGER_SCHEMA = vol.Any(
|
||||
TRIGGERS_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Remove(vol.All(str, trigger.starts_with_dot)): object,
|
||||
cv.slug: TRIGGER_SCHEMA,
|
||||
cv.underscore_slug: TRIGGER_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
|
@@ -806,10 +806,10 @@ async def test_subscribe_triggers(
|
||||
) -> None:
|
||||
"""Test trigger_platforms/subscribe command."""
|
||||
sun_trigger_descriptions = """
|
||||
sun: {}
|
||||
_: {}
|
||||
"""
|
||||
tag_trigger_descriptions = """
|
||||
tag: {}
|
||||
_: {}
|
||||
"""
|
||||
|
||||
def _load_yaml(fname, secrets=None):
|
||||
|
@@ -977,7 +977,7 @@ async def test_zwave_js_event_invalid_config_entry_id(
|
||||
async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
|
||||
"""Test invalid trigger configs."""
|
||||
with pytest.raises(vol.Invalid):
|
||||
await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config(
|
||||
await TRIGGERS["event"].async_validate_trigger_config(
|
||||
hass,
|
||||
{
|
||||
"platform": f"{DOMAIN}.event",
|
||||
@@ -988,7 +988,7 @@ async def test_invalid_trigger_configs(hass: HomeAssistant) -> None:
|
||||
)
|
||||
|
||||
with pytest.raises(vol.Invalid):
|
||||
await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config(
|
||||
await TRIGGERS["value_updated"].async_validate_trigger_config(
|
||||
hass,
|
||||
{
|
||||
"platform": f"{DOMAIN}.value_updated",
|
||||
@@ -1026,7 +1026,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
|
||||
await hass.config_entries.async_unload(integration.entry_id)
|
||||
|
||||
# Test full validation for both events
|
||||
assert await TRIGGERS[f"{DOMAIN}.value_updated"].async_validate_trigger_config(
|
||||
assert await TRIGGERS["value_updated"].async_validate_trigger_config(
|
||||
hass,
|
||||
{
|
||||
"platform": f"{DOMAIN}.value_updated",
|
||||
@@ -1036,7 +1036,7 @@ async def test_zwave_js_trigger_config_entry_unloaded(
|
||||
},
|
||||
)
|
||||
|
||||
assert await TRIGGERS[f"{DOMAIN}.event"].async_validate_trigger_config(
|
||||
assert await TRIGGERS["event"].async_validate_trigger_config(
|
||||
hass,
|
||||
{
|
||||
"platform": f"{DOMAIN}.event",
|
||||
|
36
tests/helpers/test_automation.py
Normal file
36
tests/helpers/test_automation.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Test automation helpers."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.helpers.automation import (
|
||||
get_absolute_description_key,
|
||||
get_relative_description_key,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("relative_key", "absolute_key"),
|
||||
[
|
||||
("turned_on", "homeassistant.turned_on"),
|
||||
("_", "homeassistant"),
|
||||
("_state", "state"),
|
||||
],
|
||||
)
|
||||
def test_absolute_description_key(relative_key: str, absolute_key: str) -> None:
|
||||
"""Test absolute description key."""
|
||||
DOMAIN = "homeassistant"
|
||||
assert get_absolute_description_key(DOMAIN, relative_key) == absolute_key
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("relative_key", "absolute_key"),
|
||||
[
|
||||
("turned_on", "homeassistant.turned_on"),
|
||||
("_", "homeassistant"),
|
||||
("_state", "state"),
|
||||
],
|
||||
)
|
||||
def test_relative_description_key(relative_key: str, absolute_key: str) -> None:
|
||||
"""Test relative description key."""
|
||||
DOMAIN = "homeassistant"
|
||||
assert get_relative_description_key(DOMAIN, absolute_key) == relative_key
|
@@ -50,7 +50,7 @@ async def test_trigger_subtype(hass: HomeAssistant) -> None:
|
||||
"homeassistant.helpers.trigger.async_get_integration",
|
||||
return_value=MagicMock(async_get_platform=AsyncMock()),
|
||||
) as integration_mock:
|
||||
await _async_get_trigger_platform(hass, {"platform": "test.subtype"})
|
||||
await _async_get_trigger_platform(hass, "test.subtype")
|
||||
assert integration_mock.call_args == call(hass, "test")
|
||||
|
||||
|
||||
@@ -493,8 +493,8 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
|
||||
hass: HomeAssistant,
|
||||
) -> dict[str, type[Trigger]]:
|
||||
return {
|
||||
"test": MockTrigger1,
|
||||
"test.trig_2": MockTrigger2,
|
||||
"_": MockTrigger1,
|
||||
"trig_2": MockTrigger2,
|
||||
}
|
||||
|
||||
mock_integration(hass, MockModule("test"))
|
||||
@@ -534,7 +534,7 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
|
||||
"sun_trigger_descriptions",
|
||||
[
|
||||
"""
|
||||
sun:
|
||||
_:
|
||||
fields:
|
||||
event:
|
||||
example: sunrise
|
||||
@@ -551,7 +551,7 @@ async def test_platform_multiple_triggers(hass: HomeAssistant) -> None:
|
||||
.anchor: &anchor
|
||||
- sunrise
|
||||
- sunset
|
||||
sun:
|
||||
_:
|
||||
fields:
|
||||
event:
|
||||
example: sunrise
|
||||
@@ -569,7 +569,7 @@ async def test_async_get_all_descriptions(
|
||||
) -> None:
|
||||
"""Test async_get_all_descriptions."""
|
||||
tag_trigger_descriptions = """
|
||||
tag:
|
||||
_:
|
||||
fields:
|
||||
entity:
|
||||
selector:
|
||||
@@ -607,7 +607,7 @@ async def test_async_get_all_descriptions(
|
||||
|
||||
# Test we only load triggers.yaml for integrations with triggers,
|
||||
# system_health has no triggers
|
||||
assert proxy_load_triggers_files.mock_calls[0][1][1] == unordered(
|
||||
assert proxy_load_triggers_files.mock_calls[0][1][0] == unordered(
|
||||
[
|
||||
await async_get_integration(hass, DOMAIN_SUN),
|
||||
]
|
||||
@@ -615,7 +615,7 @@ async def test_async_get_all_descriptions(
|
||||
|
||||
# system_health does not have triggers and should not be in descriptions
|
||||
assert descriptions == {
|
||||
DOMAIN_SUN: {
|
||||
"sun": {
|
||||
"fields": {
|
||||
"event": {
|
||||
"example": "sunrise",
|
||||
@@ -650,7 +650,7 @@ async def test_async_get_all_descriptions(
|
||||
new_descriptions = await trigger.async_get_all_descriptions(hass)
|
||||
assert new_descriptions is not descriptions
|
||||
assert new_descriptions == {
|
||||
DOMAIN_SUN: {
|
||||
"sun": {
|
||||
"fields": {
|
||||
"event": {
|
||||
"example": "sunrise",
|
||||
@@ -666,7 +666,7 @@ async def test_async_get_all_descriptions(
|
||||
"offset": {"selector": {"time": {}}},
|
||||
}
|
||||
},
|
||||
DOMAIN_TAG: {
|
||||
"tag": {
|
||||
"fields": {
|
||||
"entity": {
|
||||
"selector": {
|
||||
@@ -736,7 +736,7 @@ async def test_async_get_all_descriptions_with_bad_description(
|
||||
) -> None:
|
||||
"""Test async_get_all_descriptions."""
|
||||
sun_service_descriptions = """
|
||||
sun:
|
||||
_:
|
||||
fields: not_a_dict
|
||||
"""
|
||||
|
||||
@@ -760,7 +760,7 @@ async def test_async_get_all_descriptions_with_bad_description(
|
||||
|
||||
assert (
|
||||
"Unable to parse triggers.yaml for the sun integration: "
|
||||
"expected a dictionary for dictionary value @ data['sun']['fields']"
|
||||
"expected a dictionary for dictionary value @ data['_']['fields']"
|
||||
) in caplog.text
|
||||
|
||||
|
||||
@@ -787,7 +787,7 @@ async def test_subscribe_triggers(
|
||||
) -> None:
|
||||
"""Test trigger.async_subscribe_platform_events."""
|
||||
sun_trigger_descriptions = """
|
||||
sun: {}
|
||||
_: {}
|
||||
"""
|
||||
|
||||
def _load_yaml(fname, secrets=None):
|
||||
|
Reference in New Issue
Block a user