Add support for condition platforms to provide multiple triggers

This commit is contained in:
Erik
2025-06-23 19:37:26 +02:00
parent eff35e93bd
commit fdb1f99835
2 changed files with 86 additions and 36 deletions

View File

@@ -11,6 +11,7 @@ from homeassistant.const import CONF_CONDITION, SUN_EVENT_SUNRISE, SUN_EVENT_SUN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import ( from homeassistant.helpers.condition import (
Condition,
ConditionCheckerType, ConditionCheckerType,
condition_trace_set_result, condition_trace_set_result,
condition_trace_update_result, condition_trace_update_result,
@@ -37,13 +38,6 @@ _CONDITION_SCHEMA = vol.All(
) )
async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
return _CONDITION_SCHEMA(config) # type: ignore[no-any-return]
def sun( def sun(
hass: HomeAssistant, hass: HomeAssistant,
before: str | None = None, before: str | None = None,
@@ -128,16 +122,41 @@ def sun(
return True return True
def async_condition_from_config(config: ConfigType) -> ConditionCheckerType: class SunCondition(Condition):
"""Wrap action method with sun based condition.""" """Sun condition."""
before = config.get("before")
after = config.get("after")
before_offset = config.get("before_offset")
after_offset = config.get("after_offset")
@trace_condition_function def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: """Initialize condition."""
"""Validate time based if-condition.""" self._config = config
return sun(hass, before, after, before_offset, after_offset) self._hass = hass
return sun_if @classmethod
async def async_validate_condition_config(
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
return _CONDITION_SCHEMA(config) # type: ignore[no-any-return]
async def async_condition_from_config(self) -> ConditionCheckerType:
"""Wrap action method with sun based condition."""
before = self._config.get("before")
after = self._config.get("after")
before_offset = self._config.get("before_offset")
after_offset = self._config.get("after_offset")
@trace_condition_function
def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition."""
return sun(hass, before, after, before_offset, after_offset)
return sun_if
CONDITIONS: dict[str, type[Condition]] = {
"sun": SunCondition,
}
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
"""Return the sun conditions."""
return CONDITIONS

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import abc
import asyncio import asyncio
from collections import deque from collections import deque
from collections.abc import Callable, Container, Generator from collections.abc import Callable, Container, Generator
@@ -75,7 +76,7 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config"
FROM_CONFIG_FORMAT = "{}_from_config" FROM_CONFIG_FORMAT = "{}_from_config"
VALIDATE_CONFIG_FORMAT = "{}_validate_config" VALIDATE_CONFIG_FORMAT = "{}_validate_config"
_PLATFORM_ALIASES = { _PLATFORM_ALIASES: dict[str | None, str | None] = {
"and": None, "and": None,
"device": "device_automation", "device": "device_automation",
"not": None, "not": None,
@@ -93,9 +94,32 @@ INPUT_ENTITY_ID = re.compile(
) )
class Condition(abc.ABC):
"""Condition class."""
def __init__(self, hass: HomeAssistant, config: ConfigType) -> None:
"""Initialize condition."""
@classmethod
@abc.abstractmethod
async def async_validate_condition_config(
cls, hass: HomeAssistant, config: ConfigType
) -> ConfigType:
"""Validate config."""
@abc.abstractmethod
async def async_condition_from_config(self) -> ConditionCheckerType:
"""Evaluate state based on configuration."""
class ConditionProtocol(Protocol): class ConditionProtocol(Protocol):
"""Define the format of condition modules.""" """Define the format of condition modules."""
async def async_get_conditions(
self, hass: HomeAssistant
) -> dict[str, type[Condition]]:
"""Return the conditions provided by this integration."""
async def async_validate_condition_config( async def async_validate_condition_config(
self, hass: HomeAssistant, config: ConfigType self, hass: HomeAssistant, config: ConfigType
) -> ConfigType: ) -> ConfigType:
@@ -179,7 +203,9 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke
async def _async_get_condition_platform( async def _async_get_condition_platform(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> ConditionProtocol | None: ) -> ConditionProtocol | None:
platform = config[CONF_CONDITION] condition_key: str = config[CONF_CONDITION]
platform_and_sub_type = condition_key.partition(".")
platform: str | None = platform_and_sub_type[0]
platform = _PLATFORM_ALIASES.get(platform, platform) platform = _PLATFORM_ALIASES.get(platform, platform)
if platform is None: if platform is None:
return None return None
@@ -187,7 +213,7 @@ async def _async_get_condition_platform(
integration = await async_get_integration(hass, platform) integration = await async_get_integration(hass, platform)
except IntegrationNotFound: except IntegrationNotFound:
raise HomeAssistantError( raise HomeAssistantError(
f'Invalid condition "{platform}" specified {config}' f'Invalid condition "{condition_key}" specified {config}'
) from None ) from None
try: try:
return await integration.async_get_platform("condition") return await integration.async_get_platform("condition")
@@ -205,19 +231,6 @@ async def async_from_config(
Should be run on the event loop. Should be run on the event loop.
""" """
factory: Any = None
platform = await _async_get_condition_platform(hass, config)
if platform is None:
condition = config.get(CONF_CONDITION)
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory:
break
else:
factory = platform.async_condition_from_config
# Check if condition is not enabled # Check if condition is not enabled
if CONF_ENABLED in config: if CONF_ENABLED in config:
enabled = config[CONF_ENABLED] enabled = config[CONF_ENABLED]
@@ -239,6 +252,21 @@ async def async_from_config(
return disabled_condition return disabled_condition
condition: str = config[CONF_CONDITION]
factory: Any = None
platform = await _async_get_condition_platform(hass, config)
if platform is None:
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition), None)
if factory:
break
else:
condition_descriptors = await platform.async_get_conditions(hass)
condition_instance = condition_descriptors[condition](hass, config)
return await condition_instance.async_condition_from_config()
# Check for partials to properly determine if coroutine function # Check for partials to properly determine if coroutine function
check_factory = factory check_factory = factory
while isinstance(check_factory, ft.partial): while isinstance(check_factory, ft.partial):
@@ -936,7 +964,7 @@ async def async_validate_condition_config(
hass: HomeAssistant, config: ConfigType hass: HomeAssistant, config: ConfigType
) -> ConfigType: ) -> ConfigType:
"""Validate config.""" """Validate config."""
condition = config[CONF_CONDITION] condition: str = config[CONF_CONDITION]
if condition in ("and", "not", "or"): if condition in ("and", "not", "or"):
conditions = [] conditions = []
for sub_cond in config["conditions"]: for sub_cond in config["conditions"]:
@@ -947,7 +975,10 @@ async def async_validate_condition_config(
platform = await _async_get_condition_platform(hass, config) platform = await _async_get_condition_platform(hass, config)
if platform is not None: if platform is not None:
return await platform.async_validate_condition_config(hass, config) condition_descriptors = await platform.async_get_conditions(hass)
if not (condition_class := condition_descriptors.get(condition)):
raise vol.Invalid(f"Invalid condition '{condition}' specified")
return await condition_class.async_validate_condition_config(hass, config)
if platform is None and condition in ("numeric_state", "state"): if platform is None and condition in ("numeric_state", "state"):
validator = cast( validator = cast(
Callable[[HomeAssistant, ConfigType], ConfigType], Callable[[HomeAssistant, ConfigType], ConfigType],