diff --git a/homeassistant/components/sun/condition.py b/homeassistant/components/sun/condition.py index 205f1bb8b5c..f48505b4993 100644 --- a/homeassistant/components/sun/condition.py +++ b/homeassistant/components/sun/condition.py @@ -11,6 +11,7 @@ from homeassistant.const import CONF_CONDITION, SUN_EVENT_SUNRISE, SUN_EVENT_SUN from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv from homeassistant.helpers.condition import ( + Condition, ConditionCheckerType, condition_trace_set_result, condition_trace_update_result, @@ -37,13 +38,6 @@ _CONDITION_SCHEMA = vol.All( ) -async def async_validate_condition_config( - hass: HomeAssistant, config: ConfigType -) -> ConfigType: - """Validate config.""" - return _CONDITION_SCHEMA(config) # type: ignore[no-any-return] - - def sun( hass: HomeAssistant, before: str | None = None, @@ -128,16 +122,41 @@ def sun( return True -def async_condition_from_config(config: ConfigType) -> ConditionCheckerType: - """Wrap action method with sun based condition.""" - before = config.get("before") - after = config.get("after") - before_offset = config.get("before_offset") - after_offset = config.get("after_offset") +class SunCondition(Condition): + """Sun condition.""" - @trace_condition_function - def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: - """Validate time based if-condition.""" - return sun(hass, before, after, before_offset, after_offset) + def __init__(self, hass: HomeAssistant, config: ConfigType) -> None: + """Initialize condition.""" + self._config = config + self._hass = hass - return sun_if + @classmethod + async def async_validate_condition_config( + cls, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + return _CONDITION_SCHEMA(config) # type: ignore[no-any-return] + + async def async_condition_from_config(self) -> ConditionCheckerType: + """Wrap action method with sun based condition.""" + before = self._config.get("before") + after = self._config.get("after") + before_offset = self._config.get("before_offset") + after_offset = self._config.get("after_offset") + + @trace_condition_function + def sun_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool: + """Validate time based if-condition.""" + return sun(hass, before, after, before_offset, after_offset) + + return sun_if + + +CONDITIONS: dict[str, type[Condition]] = { + "sun": SunCondition, +} + + +async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]: + """Return the sun conditions.""" + return CONDITIONS diff --git a/homeassistant/helpers/condition.py b/homeassistant/helpers/condition.py index fbdf2dce7b1..39ecaa5d7bc 100644 --- a/homeassistant/helpers/condition.py +++ b/homeassistant/helpers/condition.py @@ -2,6 +2,7 @@ from __future__ import annotations +import abc import asyncio from collections import deque from collections.abc import Callable, Container, Generator @@ -75,7 +76,7 @@ ASYNC_FROM_CONFIG_FORMAT = "async_{}_from_config" FROM_CONFIG_FORMAT = "{}_from_config" VALIDATE_CONFIG_FORMAT = "{}_validate_config" -_PLATFORM_ALIASES = { +_PLATFORM_ALIASES: dict[str | None, str | None] = { "and": None, "device": "device_automation", "not": None, @@ -93,9 +94,32 @@ INPUT_ENTITY_ID = re.compile( ) +class Condition(abc.ABC): + """Condition class.""" + + def __init__(self, hass: HomeAssistant, config: ConfigType) -> None: + """Initialize condition.""" + + @classmethod + @abc.abstractmethod + async def async_validate_condition_config( + cls, hass: HomeAssistant, config: ConfigType + ) -> ConfigType: + """Validate config.""" + + @abc.abstractmethod + async def async_condition_from_config(self) -> ConditionCheckerType: + """Evaluate state based on configuration.""" + + class ConditionProtocol(Protocol): """Define the format of condition modules.""" + async def async_get_conditions( + self, hass: HomeAssistant + ) -> dict[str, type[Condition]]: + """Return the conditions provided by this integration.""" + async def async_validate_condition_config( self, hass: HomeAssistant, config: ConfigType ) -> ConfigType: @@ -179,7 +203,9 @@ def trace_condition_function(condition: ConditionCheckerType) -> ConditionChecke async def _async_get_condition_platform( hass: HomeAssistant, config: ConfigType ) -> ConditionProtocol | None: - platform = config[CONF_CONDITION] + condition_key: str = config[CONF_CONDITION] + platform_and_sub_type = condition_key.partition(".") + platform: str | None = platform_and_sub_type[0] platform = _PLATFORM_ALIASES.get(platform, platform) if platform is None: return None @@ -187,7 +213,7 @@ async def _async_get_condition_platform( integration = await async_get_integration(hass, platform) except IntegrationNotFound: raise HomeAssistantError( - f'Invalid condition "{platform}" specified {config}' + f'Invalid condition "{condition_key}" specified {config}' ) from None try: return await integration.async_get_platform("condition") @@ -205,19 +231,6 @@ async def async_from_config( Should be run on the event loop. """ - factory: Any = None - platform = await _async_get_condition_platform(hass, config) - - if platform is None: - condition = config.get(CONF_CONDITION) - for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT): - factory = getattr(sys.modules[__name__], fmt.format(condition), None) - - if factory: - break - else: - factory = platform.async_condition_from_config - # Check if condition is not enabled if CONF_ENABLED in config: enabled = config[CONF_ENABLED] @@ -239,6 +252,21 @@ async def async_from_config( return disabled_condition + condition: str = config[CONF_CONDITION] + factory: Any = None + platform = await _async_get_condition_platform(hass, config) + + if platform is None: + for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT): + factory = getattr(sys.modules[__name__], fmt.format(condition), None) + + if factory: + break + else: + condition_descriptors = await platform.async_get_conditions(hass) + condition_instance = condition_descriptors[condition](hass, config) + return await condition_instance.async_condition_from_config() + # Check for partials to properly determine if coroutine function check_factory = factory while isinstance(check_factory, ft.partial): @@ -936,7 +964,7 @@ async def async_validate_condition_config( hass: HomeAssistant, config: ConfigType ) -> ConfigType: """Validate config.""" - condition = config[CONF_CONDITION] + condition: str = config[CONF_CONDITION] if condition in ("and", "not", "or"): conditions = [] for sub_cond in config["conditions"]: @@ -947,7 +975,10 @@ async def async_validate_condition_config( platform = await _async_get_condition_platform(hass, config) if platform is not None: - return await platform.async_validate_condition_config(hass, config) + condition_descriptors = await platform.async_get_conditions(hass) + if not (condition_class := condition_descriptors.get(condition)): + raise vol.Invalid(f"Invalid condition '{condition}' specified") + return await condition_class.async_validate_condition_config(hass, config) if platform is None and condition in ("numeric_state", "state"): validator = cast( Callable[[HomeAssistant, ConfigType], ConfigType],