mirror of
https://github.com/home-assistant/core.git
synced 2026-05-07 00:56:50 +02:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7ef88f550f | |||
| c283826369 | |||
| 13565f5c94 | |||
| c208d68292 | |||
| 8bdb5e7a3c | |||
| d8d8bb23a5 | |||
| 5a10e105a8 | |||
| 65a68c138c | |||
| 8237c4db12 | |||
| d380ff61a5 | |||
| 0473407d38 | |||
| f7aecb654b | |||
| f7a91721dc | |||
| d200e547e1 |
@@ -4,7 +4,11 @@ from collections.abc import Mapping
|
||||
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers.condition import Condition, EntityConditionBase
|
||||
from homeassistant.helpers.condition import (
|
||||
ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL_FOR,
|
||||
Condition,
|
||||
EntityConditionBase,
|
||||
)
|
||||
|
||||
from .const import ATTR_IS_CLOSED, DOMAIN, CoverDeviceClass
|
||||
from .models import CoverDomainSpec
|
||||
@@ -14,6 +18,7 @@ class CoverConditionBase(EntityConditionBase):
|
||||
"""Base condition for cover state checks."""
|
||||
|
||||
_domain_specs: Mapping[str, CoverDomainSpec]
|
||||
_schema = ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL_FOR
|
||||
|
||||
def is_valid_state(self, entity_state: State) -> bool:
|
||||
"""Check if the state matches the expected cover state."""
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
options:
|
||||
- all
|
||||
- any
|
||||
for:
|
||||
required: true
|
||||
default: 00:00:00
|
||||
selector:
|
||||
duration:
|
||||
|
||||
awning_is_closed:
|
||||
fields: *condition_common_fields
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"common": {
|
||||
"condition_behavior_name": "Condition passes if",
|
||||
"condition_for_name": "For at least",
|
||||
"trigger_behavior_name": "Trigger when",
|
||||
"trigger_for_name": "For at least"
|
||||
},
|
||||
@@ -10,6 +11,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Awning is closed"
|
||||
@@ -19,6 +23,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Awning is open"
|
||||
@@ -28,6 +35,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Blind is closed"
|
||||
@@ -37,6 +47,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Blind is open"
|
||||
@@ -46,6 +59,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Curtain is closed"
|
||||
@@ -55,6 +71,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Curtain is open"
|
||||
@@ -64,6 +83,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Shade is closed"
|
||||
@@ -73,6 +95,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Shade is open"
|
||||
@@ -82,6 +107,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Shutter is closed"
|
||||
@@ -91,6 +119,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::cover::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::cover::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Shutter is open"
|
||||
|
||||
@@ -7,11 +7,10 @@ from typing import Any, Protocol
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_DOMAIN, CONF_OPTIONS
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
ConditionChecker,
|
||||
ConditionCheckerType,
|
||||
ConditionConfig,
|
||||
)
|
||||
@@ -54,6 +53,7 @@ class DeviceCondition(Condition):
|
||||
"""Device condition."""
|
||||
|
||||
_config: ConfigType
|
||||
_platform_checker: ConditionCheckerType
|
||||
|
||||
@classmethod
|
||||
async def async_validate_complete_config(
|
||||
@@ -87,20 +87,20 @@ class DeviceCondition(Condition):
|
||||
assert config.options is not None
|
||||
self._config = config.options
|
||||
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
async def async_setup(self) -> None:
|
||||
"""Test a device condition."""
|
||||
platform = await async_get_device_automation_platform(
|
||||
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
|
||||
)
|
||||
platform_checker = platform.async_condition_from_config(
|
||||
self._platform_checker = platform.async_condition_from_config(
|
||||
self._hass, self._config
|
||||
)
|
||||
|
||||
def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
|
||||
result = platform_checker(self._hass, variables)
|
||||
return result is not False
|
||||
|
||||
return checker
|
||||
@callback
|
||||
def _async_check(self, variables: TemplateVarsType = None, **kwargs: Any) -> bool:
|
||||
"""Check the condition."""
|
||||
result = self._platform_checker(self._hass, variables)
|
||||
return result is not False
|
||||
|
||||
|
||||
CONDITIONS: dict[str, type[Condition]] = {
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
options:
|
||||
- all
|
||||
- any
|
||||
for:
|
||||
required: true
|
||||
default: 00:00:00
|
||||
selector:
|
||||
duration:
|
||||
|
||||
is_closed:
|
||||
fields: *condition_common_fields
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"common": {
|
||||
"condition_behavior_name": "Condition passes if",
|
||||
"condition_for_name": "For at least",
|
||||
"trigger_behavior_name": "Trigger when",
|
||||
"trigger_for_name": "For at least"
|
||||
},
|
||||
@@ -10,6 +11,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::door::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::door::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Door is closed"
|
||||
@@ -19,6 +23,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::door::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::door::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Door is open"
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
options:
|
||||
- all
|
||||
- any
|
||||
for:
|
||||
required: true
|
||||
default: 00:00:00
|
||||
selector:
|
||||
duration:
|
||||
|
||||
is_closed:
|
||||
fields: *condition_common_fields
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"common": {
|
||||
"condition_behavior_name": "Condition passes if",
|
||||
"condition_for_name": "For at least",
|
||||
"trigger_behavior_name": "Trigger when",
|
||||
"trigger_for_name": "For at least"
|
||||
},
|
||||
@@ -10,6 +11,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::garage_door::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::garage_door::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Garage door is closed"
|
||||
@@ -19,6 +23,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::garage_door::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::garage_door::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Garage door is open"
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
options:
|
||||
- all
|
||||
- any
|
||||
for:
|
||||
required: true
|
||||
default: 00:00:00
|
||||
selector:
|
||||
duration:
|
||||
|
||||
is_closed:
|
||||
fields: *condition_common_fields
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"common": {
|
||||
"condition_behavior_name": "Condition passes if",
|
||||
"condition_for_name": "For at least",
|
||||
"trigger_behavior_name": "Trigger when",
|
||||
"trigger_for_name": "For at least"
|
||||
},
|
||||
@@ -10,6 +11,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::gate::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::gate::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Gate is closed"
|
||||
@@ -19,6 +23,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::gate::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::gate::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Gate is open"
|
||||
|
||||
@@ -8,12 +8,11 @@ from typing import Any, Unpack, cast
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import CONF_OPTIONS, SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
ConditionChecker,
|
||||
ConditionCheckParams,
|
||||
ConditionConfig,
|
||||
condition_trace_set_result,
|
||||
@@ -151,19 +150,21 @@ class SunCondition(Condition):
|
||||
super().__init__(hass, config)
|
||||
assert config.options is not None
|
||||
self._options = config.options
|
||||
self._before = self._options.get("before")
|
||||
self._after = self._options.get("after")
|
||||
self._before_offset = self._options.get("before_offset")
|
||||
self._after_offset = self._options.get("after_offset")
|
||||
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Wrap action method with sun based condition."""
|
||||
before = self._options.get("before")
|
||||
after = self._options.get("after")
|
||||
before_offset = self._options.get("before_offset")
|
||||
after_offset = self._options.get("after_offset")
|
||||
|
||||
def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Validate time based if-condition."""
|
||||
return sun(self._hass, before, after, before_offset, after_offset)
|
||||
|
||||
return sun_if
|
||||
@callback
|
||||
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Check the condition."""
|
||||
return sun(
|
||||
self._hass,
|
||||
self._before,
|
||||
self._after,
|
||||
self._before_offset,
|
||||
self._after_offset,
|
||||
)
|
||||
|
||||
|
||||
CONDITIONS: dict[str, type[Condition]] = {
|
||||
|
||||
@@ -1024,10 +1024,11 @@ async def handle_test_condition(
|
||||
# Do static + dynamic validation of the condition
|
||||
config = await async_validate_condition_config(hass, msg["condition"])
|
||||
# Test the condition
|
||||
check_condition = await async_condition_from_config(hass, config)
|
||||
condition = await async_condition_from_config(hass, config)
|
||||
connection.send_result(
|
||||
msg["id"], {"result": check_condition(hass, msg.get("variables"))}
|
||||
msg["id"], {"result": condition.async_check(variables=msg.get("variables"))}
|
||||
)
|
||||
condition.async_unload()
|
||||
|
||||
|
||||
@decorators.websocket_command(
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
options:
|
||||
- all
|
||||
- any
|
||||
for:
|
||||
required: true
|
||||
default: 00:00:00
|
||||
selector:
|
||||
duration:
|
||||
|
||||
is_closed:
|
||||
fields: *condition_common_fields
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"common": {
|
||||
"condition_behavior_name": "Condition passes if",
|
||||
"condition_for_name": "For at least",
|
||||
"trigger_behavior_name": "Trigger when",
|
||||
"trigger_for_name": "For at least"
|
||||
},
|
||||
@@ -10,6 +11,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::window::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::window::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Window is closed"
|
||||
@@ -19,6 +23,9 @@
|
||||
"fields": {
|
||||
"behavior": {
|
||||
"name": "[%key:component::window::common::condition_behavior_name%]"
|
||||
},
|
||||
"for": {
|
||||
"name": "[%key:component::window::common::condition_for_name%]"
|
||||
}
|
||||
},
|
||||
"name": "Window is open"
|
||||
|
||||
@@ -16,13 +16,12 @@ from homeassistant.const import (
|
||||
STATE_UNAVAILABLE,
|
||||
STATE_UNKNOWN,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.core import HomeAssistant, State, callback
|
||||
from homeassistant.exceptions import ConditionErrorContainer, ConditionErrorMessage
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
|
||||
from homeassistant.helpers.condition import (
|
||||
Condition,
|
||||
ConditionChecker,
|
||||
ConditionCheckParams,
|
||||
ConditionConfig,
|
||||
)
|
||||
@@ -117,44 +116,40 @@ class ZoneCondition(Condition):
|
||||
super().__init__(hass, config)
|
||||
assert config.options is not None
|
||||
self._options = config.options
|
||||
self._entity_ids = self._options.get(CONF_ENTITY_ID, [])
|
||||
self._zone_entity_ids = self._options.get(CONF_ZONE, [])
|
||||
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Wrap action method with zone based condition."""
|
||||
entity_ids = self._options.get(CONF_ENTITY_ID, [])
|
||||
zone_entity_ids = self._options.get(CONF_ZONE, [])
|
||||
@callback
|
||||
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test if condition."""
|
||||
errors = []
|
||||
|
||||
def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test if condition."""
|
||||
errors = []
|
||||
|
||||
all_ok = True
|
||||
for entity_id in entity_ids:
|
||||
entity_ok = False
|
||||
for zone_entity_id in zone_entity_ids:
|
||||
try:
|
||||
if zone(self._hass, zone_entity_id, entity_id):
|
||||
entity_ok = True
|
||||
except ConditionErrorMessage as ex:
|
||||
errors.append(
|
||||
ConditionErrorMessage(
|
||||
"zone",
|
||||
(
|
||||
f"error matching {entity_id} with {zone_entity_id}:"
|
||||
f" {ex.message}"
|
||||
),
|
||||
)
|
||||
all_ok = True
|
||||
for entity_id in self._entity_ids:
|
||||
entity_ok = False
|
||||
for zone_entity_id in self._zone_entity_ids:
|
||||
try:
|
||||
if zone(self._hass, zone_entity_id, entity_id):
|
||||
entity_ok = True
|
||||
except ConditionErrorMessage as ex:
|
||||
errors.append(
|
||||
ConditionErrorMessage(
|
||||
"zone",
|
||||
(
|
||||
f"error matching {entity_id} with {zone_entity_id}:"
|
||||
f" {ex.message}"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not entity_ok:
|
||||
all_ok = False
|
||||
if not entity_ok:
|
||||
all_ok = False
|
||||
|
||||
# Raise the errors only if no definitive result was found
|
||||
if errors and not all_ok:
|
||||
raise ConditionErrorContainer("zone", errors=errors)
|
||||
# Raise the errors only if no definitive result was found
|
||||
if errors and not all_ok:
|
||||
raise ConditionErrorContainer("zone", errors=errors)
|
||||
|
||||
return all_ok
|
||||
|
||||
return if_in_zone
|
||||
return all_ok
|
||||
|
||||
|
||||
CONDITIONS: dict[str, type[Condition]] = {
|
||||
|
||||
+284
-116
@@ -93,7 +93,12 @@ from .selector import (
|
||||
NumericThresholdType,
|
||||
TargetSelector,
|
||||
)
|
||||
from .target import TargetSelection, async_extract_referenced_entity_ids
|
||||
from .target import (
|
||||
TargetSelection,
|
||||
TargetStateChangedData,
|
||||
async_extract_referenced_entity_ids,
|
||||
async_track_target_selector_state_change_event,
|
||||
)
|
||||
from .template import Template, render_complex
|
||||
from .trace import (
|
||||
TraceElement,
|
||||
@@ -284,10 +289,97 @@ _CONDITION_SCHEMA = _CONDITION_BASE_SCHEMA.extend(
|
||||
)
|
||||
|
||||
|
||||
class Condition(abc.ABC):
|
||||
"""Condition class."""
|
||||
class ConditionChecker(abc.ABC):
|
||||
"""Base class for condition checkers."""
|
||||
|
||||
_hass: HomeAssistant
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize condition checker."""
|
||||
self._hass = hass
|
||||
self._on_unload: list[Callable[[], None]] = []
|
||||
|
||||
def __call__(
|
||||
self, _: HomeAssistant, variables: TemplateVarsType | None = None
|
||||
) -> bool | None:
|
||||
"""Check the condition."""
|
||||
return self.async_check(variables=variables)
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Clean up when the checker is deleted."""
|
||||
try:
|
||||
self.async_unload()
|
||||
except Exception:
|
||||
_LOGGER.exception("Error while unloading condition checker")
|
||||
|
||||
def async_check(
|
||||
self, *, variables: TemplateVarsType | None = None, **kwargs: Any
|
||||
) -> bool | None:
|
||||
"""Check the condition."""
|
||||
with trace_condition(variables):
|
||||
result = self._async_check(variables=variables, **kwargs)
|
||||
condition_trace_update_result(result=result)
|
||||
return result
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up the condition checker.
|
||||
|
||||
Intended to be overridden in derived classes that need to do async setup.
|
||||
"""
|
||||
|
||||
@callback
|
||||
def async_on_unload(self, func: Callable[[], None]) -> None:
|
||||
"""Add a function to call when config entry is unloaded."""
|
||||
self._on_unload.append(func)
|
||||
|
||||
def async_unload(self) -> None:
|
||||
"""Clean up any resources held by the checker."""
|
||||
for cb in self._on_unload:
|
||||
cb()
|
||||
self._on_unload.clear()
|
||||
|
||||
@abc.abstractmethod
|
||||
@callback
|
||||
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool | None:
|
||||
"""Check the condition."""
|
||||
|
||||
|
||||
class LegacyConditionChecker(ConditionChecker):
|
||||
"""Condition checker wrapping a legacy condition factory function."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, checker: ConditionCheckerType) -> None:
|
||||
"""Initialize condition checker."""
|
||||
super().__init__(hass)
|
||||
self._checker = checker
|
||||
|
||||
@callback
|
||||
def _async_check(self, variables: TemplateVarsType = None, **kwargs: Any) -> bool:
|
||||
return self._checker(self._hass, variables)
|
||||
|
||||
|
||||
class DisabledConditionChecker(ConditionChecker):
|
||||
"""Condition checker for disabled conditions."""
|
||||
|
||||
@callback
|
||||
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class CompoundConditionChecker(ConditionChecker):
|
||||
"""Base class for compound condition checkers (and/or/not)."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, checks: list[ConditionChecker]) -> None:
|
||||
"""Initialize condition checker."""
|
||||
super().__init__(hass)
|
||||
self._checks = checks
|
||||
|
||||
def async_unload(self) -> None:
|
||||
"""Clean up child conditions."""
|
||||
for check in self._checks:
|
||||
check.async_unload()
|
||||
super().async_unload()
|
||||
|
||||
|
||||
class Condition(ConditionChecker):
|
||||
"""Condition class."""
|
||||
|
||||
@classmethod
|
||||
async def async_validate_complete_config(
|
||||
@@ -323,11 +415,7 @@ class Condition(abc.ABC):
|
||||
|
||||
def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
|
||||
"""Initialize condition."""
|
||||
self._hass = hass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Get the condition checker."""
|
||||
super().__init__(hass)
|
||||
|
||||
|
||||
ATTR_BEHAVIOR: Final = "behavior"
|
||||
@@ -376,14 +464,99 @@ class EntityConditionBase(Condition):
|
||||
if TYPE_CHECKING:
|
||||
assert config.target
|
||||
assert config.options
|
||||
self._target_config = config.target
|
||||
self._target_selection = TargetSelection(config.target)
|
||||
self._behavior = config.options[ATTR_BEHAVIOR]
|
||||
self._duration: timedelta | None = config.options.get(CONF_FOR)
|
||||
if self._behavior == BEHAVIOR_ANY:
|
||||
self._matcher = self._check_any_match_state
|
||||
elif self._behavior == BEHAVIOR_ALL:
|
||||
self._matcher = self._check_all_match_state
|
||||
self._valid_since: dict[str, datetime] = {}
|
||||
|
||||
def entity_filter(self, entities: set[str]) -> set[str]:
|
||||
"""Filter entities matching any of the domain specs."""
|
||||
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
|
||||
|
||||
@property
|
||||
def _needs_duration_tracking(self) -> bool:
|
||||
"""Whether this condition needs active state change tracking for duration.
|
||||
|
||||
Conditions that are true for a single main state value can use
|
||||
state.last_changed directly. Conditions that track attributes or
|
||||
match multiple states need active tracking because last_changed
|
||||
does not capture those transitions.
|
||||
"""
|
||||
return True
|
||||
|
||||
def _prime_valid_since(self, entity_id: str) -> None:
|
||||
"""Prime _valid_since for an entity already in a valid state.
|
||||
|
||||
For state-based conditions (value_source is None), last_changed
|
||||
accurately reflects when the state changed to the current value.
|
||||
For attribute-based conditions, last_changed only tracks main state
|
||||
changes, so we use last_updated which is bumped on any update
|
||||
(state or attributes). This is conservative - the tracked attribute
|
||||
may have held its value longer - but it's the best we can do
|
||||
to avoid false positives.
|
||||
"""
|
||||
if (
|
||||
(_state := self._hass.states.get(entity_id)) is not None
|
||||
and _state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
and self.is_valid_state(_state)
|
||||
):
|
||||
domain_spec = self._domain_specs[_state.domain]
|
||||
if domain_spec.value_source is None:
|
||||
self._valid_since[entity_id] = _state.last_changed
|
||||
else:
|
||||
self._valid_since[entity_id] = _state.last_updated
|
||||
|
||||
@override
|
||||
async def async_setup(self) -> None:
|
||||
"""Set up state tracking for duration-based conditions."""
|
||||
await super().async_setup()
|
||||
if not self._duration or not self._needs_duration_tracking:
|
||||
return
|
||||
|
||||
@callback
|
||||
def _state_change_listener(
|
||||
data: TargetStateChangedData,
|
||||
) -> None:
|
||||
"""Track when entities enter or leave a valid state."""
|
||||
event = data.state_change_event
|
||||
entity_id = event.data["entity_id"]
|
||||
to_state = event.data["new_state"]
|
||||
|
||||
if (
|
||||
to_state is not None
|
||||
and to_state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
and self.is_valid_state(to_state)
|
||||
):
|
||||
if entity_id not in self._valid_since:
|
||||
# Prime from the state object — this handles both
|
||||
# genuine transitions and newly tracked entities
|
||||
self._prime_valid_since(entity_id)
|
||||
else:
|
||||
self._valid_since.pop(entity_id, None)
|
||||
|
||||
@callback
|
||||
def _on_entities_update(added: set[str], removed: set[str]) -> None:
|
||||
"""Handle changes to the tracked entity set."""
|
||||
for entity_id in added:
|
||||
self._prime_valid_since(entity_id)
|
||||
for entity_id in removed:
|
||||
self._valid_since.pop(entity_id, None)
|
||||
|
||||
self.async_on_unload(
|
||||
async_track_target_selector_state_change_event(
|
||||
self._hass,
|
||||
self._target_config,
|
||||
_state_change_listener,
|
||||
self.entity_filter,
|
||||
_on_entities_update,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_tracked_value(self, entity_state: State) -> Any:
|
||||
"""Get the tracked value from a state based on the DomainSpec."""
|
||||
domain_spec = self._domain_specs[entity_state.domain]
|
||||
@@ -395,56 +568,58 @@ class EntityConditionBase(Condition):
|
||||
def is_valid_state(self, entity_state: State) -> bool:
|
||||
"""Check if the state matches the expected state(s)."""
|
||||
|
||||
@override
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Get the condition checker."""
|
||||
|
||||
def check_any_match_state(states: list[State]) -> bool:
|
||||
"""Test if any entity matches the state."""
|
||||
if not self._duration:
|
||||
# Skip duration check if duration is not specified or 0
|
||||
return any(self.is_valid_state(state) for state in states)
|
||||
duration = dt_util.utcnow() - self._duration
|
||||
def _check_any_match_state(self, states: list[State]) -> bool:
|
||||
"""Test if any entity matches the state."""
|
||||
if not self._duration:
|
||||
# Skip duration check if duration is not specified or 0
|
||||
return any(self.is_valid_state(state) for state in states)
|
||||
cutoff = dt_util.utcnow() - self._duration
|
||||
if not self._needs_duration_tracking:
|
||||
return any(
|
||||
self.is_valid_state(state) and duration > state.last_changed
|
||||
self.is_valid_state(state) and state.last_changed <= cutoff
|
||||
for state in states
|
||||
)
|
||||
return any(
|
||||
self.is_valid_state(state)
|
||||
and (valid_since := self._valid_since.get(state.entity_id)) is not None
|
||||
and valid_since <= cutoff
|
||||
for state in states
|
||||
)
|
||||
|
||||
def check_all_match_state(states: list[State]) -> bool:
|
||||
"""Test if all entities match the state."""
|
||||
if not self._duration:
|
||||
# Skip duration check if duration is not specified or 0
|
||||
return all(self.is_valid_state(state) for state in states)
|
||||
duration = dt_util.utcnow() - self._duration
|
||||
def _check_all_match_state(self, states: list[State]) -> bool:
|
||||
"""Test if all entities match the state."""
|
||||
if not self._duration:
|
||||
# Skip duration check if duration is not specified or 0
|
||||
return all(self.is_valid_state(state) for state in states)
|
||||
cutoff = dt_util.utcnow() - self._duration
|
||||
if not self._needs_duration_tracking:
|
||||
return all(
|
||||
self.is_valid_state(state) and duration > state.last_changed
|
||||
self.is_valid_state(state) and state.last_changed <= cutoff
|
||||
for state in states
|
||||
)
|
||||
return all(
|
||||
self.is_valid_state(state)
|
||||
and (valid_since := self._valid_since.get(state.entity_id)) is not None
|
||||
and valid_since <= cutoff
|
||||
for state in states
|
||||
)
|
||||
|
||||
matcher: Callable[[list[State]], bool]
|
||||
if self._behavior == BEHAVIOR_ANY:
|
||||
matcher = check_any_match_state
|
||||
elif self._behavior == BEHAVIOR_ALL:
|
||||
matcher = check_all_match_state
|
||||
|
||||
def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test state condition."""
|
||||
targeted_entities = async_extract_referenced_entity_ids(
|
||||
self._hass, self._target_selection, expand_group=False
|
||||
)
|
||||
referenced_entity_ids = targeted_entities.referenced.union(
|
||||
targeted_entities.indirectly_referenced
|
||||
)
|
||||
filtered_entity_ids = self.entity_filter(referenced_entity_ids)
|
||||
entity_states = [
|
||||
_state
|
||||
for entity_id in filtered_entity_ids
|
||||
if (_state := self._hass.states.get(entity_id))
|
||||
and _state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
]
|
||||
return matcher(entity_states)
|
||||
|
||||
return test_state
|
||||
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test state condition."""
|
||||
targeted_entities = async_extract_referenced_entity_ids(
|
||||
self._hass, self._target_selection, expand_group=False
|
||||
)
|
||||
referenced_entity_ids = targeted_entities.referenced.union(
|
||||
targeted_entities.indirectly_referenced
|
||||
)
|
||||
filtered_entity_ids = self.entity_filter(referenced_entity_ids)
|
||||
entity_states = [
|
||||
_state
|
||||
for entity_id in filtered_entity_ids
|
||||
if (_state := self._hass.states.get(entity_id))
|
||||
and _state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
|
||||
]
|
||||
return self._matcher(entity_states)
|
||||
|
||||
|
||||
class EntityStateConditionBase(EntityConditionBase):
|
||||
@@ -452,6 +627,15 @@ class EntityStateConditionBase(EntityConditionBase):
|
||||
|
||||
_states: set[str | bool]
|
||||
|
||||
@property
|
||||
def _needs_duration_tracking(self) -> bool:
|
||||
"""Single-state conditions with no attribute tracking can use last_changed."""
|
||||
if len(self._states) != 1:
|
||||
return True
|
||||
return any(
|
||||
spec.value_source is not None for spec in self._domain_specs.values()
|
||||
)
|
||||
|
||||
def is_valid_state(self, entity_state: State) -> bool:
|
||||
"""Check if the state matches the expected state(s)."""
|
||||
return self._get_tracked_value(entity_state) in self._states
|
||||
@@ -739,13 +923,6 @@ class ConditionCheckParams(TypedDict, total=False):
|
||||
variables: TemplateVarsType
|
||||
|
||||
|
||||
class ConditionChecker(Protocol):
|
||||
"""Protocol for condition checker callable with typed kwargs."""
|
||||
|
||||
def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Check the condition."""
|
||||
|
||||
|
||||
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
|
||||
type ConditionCheckerTypeOptional = Callable[
|
||||
[HomeAssistant, TemplateVarsType], bool | None
|
||||
@@ -869,20 +1046,10 @@ async def _async_get_condition_platform(
|
||||
return platform, platform_module
|
||||
|
||||
|
||||
async def _async_get_checker(condition: Condition) -> ConditionCheckerType:
|
||||
new_checker = await condition.async_get_checker()
|
||||
|
||||
@trace_condition_function
|
||||
def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
return new_checker(variables=variables)
|
||||
|
||||
return checker
|
||||
|
||||
|
||||
async def async_from_config(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
) -> ConditionCheckerTypeOptional:
|
||||
) -> ConditionChecker:
|
||||
"""Turn a condition configuration into a method.
|
||||
|
||||
Should be run on the event loop.
|
||||
@@ -898,15 +1065,7 @@ async def async_from_config(
|
||||
f"Error rendering condition enabled template: {err}"
|
||||
) from err
|
||||
if not enabled:
|
||||
|
||||
@trace_condition_function
|
||||
def disabled_condition(
|
||||
hass: HomeAssistant, variables: TemplateVarsType = None
|
||||
) -> bool | None:
|
||||
"""Condition not enabled, will act as if it didn't exist."""
|
||||
return None
|
||||
|
||||
return disabled_condition
|
||||
return DisabledConditionChecker(hass)
|
||||
|
||||
condition_key: str = config[CONF_CONDITION]
|
||||
factory: Any = None
|
||||
@@ -925,7 +1084,8 @@ async def async_from_config(
|
||||
target=config.get(CONF_TARGET),
|
||||
),
|
||||
)
|
||||
return await _async_get_checker(condition)
|
||||
await condition.async_setup()
|
||||
return condition
|
||||
|
||||
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
|
||||
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
|
||||
@@ -939,30 +1099,39 @@ async def async_from_config(
|
||||
check_factory = check_factory.func
|
||||
|
||||
if inspect.iscoroutinefunction(check_factory):
|
||||
return cast(ConditionCheckerType, await factory(hass, config))
|
||||
return cast(ConditionCheckerType, factory(config))
|
||||
checker = await factory(hass, config)
|
||||
else:
|
||||
checker = factory(config)
|
||||
if isinstance(checker, ConditionChecker):
|
||||
return checker
|
||||
return LegacyConditionChecker(hass, cast(ConditionCheckerType, checker))
|
||||
|
||||
|
||||
async def async_and_from_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConditionCheckerType:
|
||||
) -> ConditionChecker:
|
||||
"""Create multi condition matcher using 'AND'."""
|
||||
checks = [await async_from_config(hass, entry) for entry in config["conditions"]]
|
||||
return AndConditionChecker(hass, checks)
|
||||
|
||||
@trace_condition_function
|
||||
def if_and_condition(
|
||||
hass: HomeAssistant, variables: TemplateVarsType = None
|
||||
) -> bool:
|
||||
|
||||
class AndConditionChecker(CompoundConditionChecker):
|
||||
"""Condition checker for 'and' compound conditions."""
|
||||
|
||||
@callback
|
||||
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test and condition."""
|
||||
errors = []
|
||||
for index, check in enumerate(checks):
|
||||
for index, check in enumerate(self._checks):
|
||||
try:
|
||||
with trace_path(["conditions", str(index)]):
|
||||
if check(hass, variables) is False:
|
||||
if check(self._hass, **kwargs) is False:
|
||||
return False
|
||||
except ConditionError as ex:
|
||||
errors.append(
|
||||
ConditionErrorIndex("and", index=index, total=len(checks), error=ex)
|
||||
ConditionErrorIndex(
|
||||
"and", index=index, total=len(self._checks), error=ex
|
||||
)
|
||||
)
|
||||
|
||||
# Raise the errors if no check was false
|
||||
@@ -971,29 +1140,32 @@ async def async_and_from_config(
|
||||
|
||||
return True
|
||||
|
||||
return if_and_condition
|
||||
|
||||
|
||||
async def async_or_from_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConditionCheckerType:
|
||||
) -> ConditionChecker:
|
||||
"""Create multi condition matcher using 'OR'."""
|
||||
checks = [await async_from_config(hass, entry) for entry in config["conditions"]]
|
||||
return OrConditionChecker(hass, checks)
|
||||
|
||||
@trace_condition_function
|
||||
def if_or_condition(
|
||||
hass: HomeAssistant, variables: TemplateVarsType = None
|
||||
) -> bool:
|
||||
|
||||
class OrConditionChecker(CompoundConditionChecker):
|
||||
"""Condition checker for 'or' compound conditions."""
|
||||
|
||||
@callback
|
||||
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test or condition."""
|
||||
errors = []
|
||||
for index, check in enumerate(checks):
|
||||
for index, check in enumerate(self._checks):
|
||||
try:
|
||||
with trace_path(["conditions", str(index)]):
|
||||
if check(hass, variables) is True:
|
||||
if check(self._hass, **kwargs) is True:
|
||||
return True
|
||||
except ConditionError as ex:
|
||||
errors.append(
|
||||
ConditionErrorIndex("or", index=index, total=len(checks), error=ex)
|
||||
ConditionErrorIndex(
|
||||
"or", index=index, total=len(self._checks), error=ex
|
||||
)
|
||||
)
|
||||
|
||||
# Raise the errors if no check was true
|
||||
@@ -1002,29 +1174,32 @@ async def async_or_from_config(
|
||||
|
||||
return False
|
||||
|
||||
return if_or_condition
|
||||
|
||||
|
||||
async def async_not_from_config(
|
||||
hass: HomeAssistant, config: ConfigType
|
||||
) -> ConditionCheckerType:
|
||||
) -> ConditionChecker:
|
||||
"""Create multi condition matcher using 'NOT'."""
|
||||
checks = [await async_from_config(hass, entry) for entry in config["conditions"]]
|
||||
return NotConditionChecker(hass, checks)
|
||||
|
||||
@trace_condition_function
|
||||
def if_not_condition(
|
||||
hass: HomeAssistant, variables: TemplateVarsType = None
|
||||
) -> bool:
|
||||
|
||||
class NotConditionChecker(CompoundConditionChecker):
|
||||
"""Condition checker for 'not' compound conditions."""
|
||||
|
||||
@callback
|
||||
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
|
||||
"""Test not condition."""
|
||||
errors = []
|
||||
for index, check in enumerate(checks):
|
||||
for index, check in enumerate(self._checks):
|
||||
try:
|
||||
with trace_path(["conditions", str(index)]):
|
||||
if check(hass, variables):
|
||||
if check(self._hass, **kwargs):
|
||||
return False
|
||||
except ConditionError as ex:
|
||||
errors.append(
|
||||
ConditionErrorIndex("not", index=index, total=len(checks), error=ex)
|
||||
ConditionErrorIndex(
|
||||
"not", index=index, total=len(self._checks), error=ex
|
||||
)
|
||||
)
|
||||
|
||||
# Raise the errors if no check was true
|
||||
@@ -1033,8 +1208,6 @@ async def async_not_from_config(
|
||||
|
||||
return True
|
||||
|
||||
return if_not_condition
|
||||
|
||||
|
||||
def numeric_state(
|
||||
hass: HomeAssistant,
|
||||
@@ -1191,7 +1364,6 @@ def async_numeric_state_from_config(config: ConfigType) -> ConditionCheckerType:
|
||||
above = config.get(CONF_ABOVE)
|
||||
value_template = config.get(CONF_VALUE_TEMPLATE)
|
||||
|
||||
@trace_condition_function
|
||||
def if_numeric_state(
|
||||
hass: HomeAssistant, variables: TemplateVarsType = None
|
||||
) -> bool:
|
||||
@@ -1310,7 +1482,6 @@ def state_from_config(config: ConfigType) -> ConditionCheckerType:
|
||||
if not isinstance(req_states, list):
|
||||
req_states = [req_states]
|
||||
|
||||
@trace_condition_function
|
||||
def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
"""Test if condition."""
|
||||
errors = []
|
||||
@@ -1372,7 +1543,6 @@ def async_template_from_config(config: ConfigType) -> ConditionCheckerType:
|
||||
"""Wrap action method with state based condition."""
|
||||
value_template = cast(Template, config.get(CONF_VALUE_TEMPLATE))
|
||||
|
||||
@trace_condition_function
|
||||
def template_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
"""Validate template based if-condition."""
|
||||
return async_template(hass, value_template, variables)
|
||||
@@ -1485,7 +1655,6 @@ def time_from_config(config: ConfigType) -> ConditionCheckerType:
|
||||
after = config.get(CONF_AFTER)
|
||||
weekday = config.get(CONF_WEEKDAY)
|
||||
|
||||
@trace_condition_function
|
||||
def time_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
"""Validate time based if-condition."""
|
||||
return time(hass, before, after, weekday)
|
||||
@@ -1499,7 +1668,6 @@ async def async_trigger_from_config(
|
||||
"""Test a trigger condition."""
|
||||
trigger_id = config[CONF_ID]
|
||||
|
||||
@trace_condition_function
|
||||
def trigger_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
|
||||
"""Validate trigger based if-condition."""
|
||||
return (
|
||||
|
||||
@@ -345,14 +345,25 @@ class TargetStateChangeTracker(TargetEntityChangeTracker):
|
||||
target_selection: TargetSelection,
|
||||
action: Callable[[TargetStateChangedData], Any],
|
||||
entity_filter: Callable[[set[str]], set[str]],
|
||||
on_entities_update: Callable[[set[str], set[str]], None] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the state change tracker."""
|
||||
super().__init__(hass, target_selection, entity_filter)
|
||||
self._action = action
|
||||
self._on_entities_update = on_entities_update
|
||||
self._state_change_unsub: CALLBACK_TYPE | None = None
|
||||
self._tracked_entities: set[str] = set()
|
||||
|
||||
def _handle_entities_update(self, tracked_entities: set[str]) -> None:
|
||||
"""Handle the tracked entities."""
|
||||
previous_entities = self._tracked_entities
|
||||
self._tracked_entities = tracked_entities
|
||||
|
||||
if self._on_entities_update is not None:
|
||||
added = tracked_entities - previous_entities
|
||||
removed = previous_entities - tracked_entities
|
||||
if added or removed:
|
||||
self._on_entities_update(added, removed)
|
||||
|
||||
@callback
|
||||
def state_change_listener(event: Event[EventStateChangedData]) -> None:
|
||||
@@ -380,6 +391,7 @@ def async_track_target_selector_state_change_event(
|
||||
target_selector_config: ConfigType,
|
||||
action: Callable[[TargetStateChangedData], Any],
|
||||
entity_filter: Callable[[set[str]], set[str]] = lambda x: x,
|
||||
on_entities_update: Callable[[set[str], set[str]], None] | None = None,
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Track state changes for entities referenced directly or indirectly in a target selector."""
|
||||
target_selection = TargetSelection(target_selector_config)
|
||||
@@ -387,5 +399,7 @@ def async_track_target_selector_state_change_event(
|
||||
raise HomeAssistantError(
|
||||
f"Target selector {target_selector_config} does not have any selectors defined"
|
||||
)
|
||||
tracker = TargetStateChangeTracker(hass, target_selection, action, entity_filter)
|
||||
tracker = TargetStateChangeTracker(
|
||||
hass, target_selection, action, entity_filter, on_entities_update
|
||||
)
|
||||
return tracker.async_setup()
|
||||
|
||||
+572
-13
@@ -2,6 +2,7 @@
|
||||
|
||||
from collections.abc import Mapping
|
||||
from contextlib import AbstractContextManager, nullcontext as does_not_raise
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
import io
|
||||
from typing import Any
|
||||
@@ -22,6 +23,7 @@ from homeassistant.components.sun import DOMAIN as SUN_DOMAIN
|
||||
from homeassistant.components.system_health import DOMAIN as SYSTEM_HEALTH_DOMAIN
|
||||
from homeassistant.const import (
|
||||
ATTR_DEVICE_CLASS,
|
||||
ATTR_LABEL_ID,
|
||||
ATTR_UNIT_OF_MEASUREMENT,
|
||||
CONF_CONDITION,
|
||||
CONF_DEVICE_ID,
|
||||
@@ -42,6 +44,7 @@ from homeassistant.helpers import (
|
||||
condition,
|
||||
config_validation as cv,
|
||||
entity_registry as er,
|
||||
label_registry as lr,
|
||||
trace,
|
||||
)
|
||||
from homeassistant.helpers.automation import (
|
||||
@@ -54,7 +57,6 @@ from homeassistant.helpers.condition import (
|
||||
BEHAVIOR_ANY,
|
||||
CONDITIONS,
|
||||
Condition,
|
||||
ConditionChecker,
|
||||
EntityNumericalConditionWithUnitBase,
|
||||
_async_get_condition_platform,
|
||||
async_validate_condition_config,
|
||||
@@ -63,7 +65,12 @@ from homeassistant.helpers.condition import (
|
||||
make_entity_state_condition,
|
||||
)
|
||||
from homeassistant.helpers.template import Template
|
||||
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
|
||||
from homeassistant.helpers.typing import (
|
||||
UNDEFINED,
|
||||
ConfigType,
|
||||
TemplateVarsType,
|
||||
UndefinedType,
|
||||
)
|
||||
from homeassistant.loader import Integration, async_get_integration
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt as dt_util
|
||||
@@ -2150,16 +2157,16 @@ async def test_platform_multiple_conditions(hass: HomeAssistant) -> None:
|
||||
class MockCondition1(MockCondition):
|
||||
"""Mock condition 1."""
|
||||
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Evaluate state based on configuration."""
|
||||
return lambda **kwargs: True
|
||||
def _async_check(self, variables: TemplateVarsType) -> bool:
|
||||
"""Check the condition."""
|
||||
return True
|
||||
|
||||
class MockCondition2(MockCondition):
|
||||
"""Mock condition 2."""
|
||||
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
"""Evaluate state based on configuration."""
|
||||
return lambda **kwargs: False
|
||||
def _async_check(self, variables: TemplateVarsType) -> bool:
|
||||
"""Check the condition."""
|
||||
return False
|
||||
|
||||
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
|
||||
return {
|
||||
@@ -2297,8 +2304,9 @@ async def test_get_condition_platform_registers_conditions(
|
||||
) -> ConfigType:
|
||||
return config
|
||||
|
||||
async def async_get_checker(self) -> ConditionChecker:
|
||||
return lambda **kwargs: True
|
||||
def _async_check(self, variables: TemplateVarsType) -> bool:
|
||||
"""Check the condition."""
|
||||
return True
|
||||
|
||||
async def async_get_conditions(
|
||||
hass: HomeAssistant,
|
||||
@@ -3103,7 +3111,7 @@ async def _setup_numerical_condition(
|
||||
entity_ids: str | list[str],
|
||||
domain_specs: Mapping[str, DomainSpec] | None = None,
|
||||
valid_unit: str | None | UndefinedType = UNDEFINED,
|
||||
) -> condition.ConditionCheckerType:
|
||||
) -> condition.ConditionChecker:
|
||||
"""Set up a numerical condition via a mock platform and return the test."""
|
||||
condition_cls = make_entity_numerical_condition(
|
||||
domain_specs or _DEFAULT_DOMAIN_SPECS, valid_unit
|
||||
@@ -3432,7 +3440,7 @@ async def _setup_numerical_condition_with_unit(
|
||||
domain_specs: Mapping[str, DomainSpec] | None = None,
|
||||
base_unit: str = UnitOfTemperature.CELSIUS,
|
||||
unit_converter: type = TemperatureConverter,
|
||||
) -> condition.ConditionCheckerType:
|
||||
) -> condition.ConditionChecker:
|
||||
"""Set up a numerical condition with unit conversion via a mock platform."""
|
||||
condition_cls = make_entity_numerical_condition_with_unit(
|
||||
domain_specs or _DEFAULT_DOMAIN_SPECS, base_unit, unit_converter
|
||||
@@ -3953,7 +3961,7 @@ async def _setup_state_condition(
|
||||
condition_options: dict[str, Any] | None = None,
|
||||
domain_specs: Mapping[str, DomainSpec] | None = None,
|
||||
support_duration: bool = False,
|
||||
) -> condition.ConditionCheckerType:
|
||||
) -> condition.ConditionChecker:
|
||||
"""Set up a state condition via a mock platform and return the checker."""
|
||||
condition_cls = make_entity_state_condition(
|
||||
domain_specs or _DEFAULT_DOMAIN_SPECS,
|
||||
@@ -4341,3 +4349,554 @@ async def test_state_condition_duration_unavailable_unknown(
|
||||
await hass.async_block_till_done()
|
||||
freezer.tick(timedelta(seconds=11))
|
||||
assert test_all(hass) is False
|
||||
|
||||
|
||||
_ATTR_DOMAIN_SPECS: Mapping[str, DomainSpec] = {
|
||||
"test": DomainSpec(value_source="test_attr")
|
||||
}
|
||||
|
||||
|
||||
async def _setup_attr_state_condition(
|
||||
hass: HomeAssistant,
|
||||
entity_ids: str | list[str],
|
||||
states: str | bool | set[str | bool],
|
||||
condition_options: dict[str, Any] | None = None,
|
||||
) -> condition.ConditionChecker:
|
||||
"""Set up an attribute-based state condition and return the checker."""
|
||||
condition_cls = make_entity_state_condition(
|
||||
_ATTR_DOMAIN_SPECS,
|
||||
states,
|
||||
support_duration=True,
|
||||
)
|
||||
|
||||
async def async_get_conditions(
|
||||
hass: HomeAssistant,
|
||||
) -> dict[str, type[Condition]]:
|
||||
return {"_": condition_cls}
|
||||
|
||||
mock_integration(hass, MockModule("test"))
|
||||
mock_platform(
|
||||
hass, "test.condition", Mock(async_get_conditions=async_get_conditions)
|
||||
)
|
||||
|
||||
if isinstance(entity_ids, str):
|
||||
entity_ids = [entity_ids]
|
||||
|
||||
config: dict[str, Any] = {
|
||||
CONF_CONDITION: "test",
|
||||
CONF_TARGET: {CONF_ENTITY_ID: entity_ids},
|
||||
CONF_OPTIONS: condition_options or {},
|
||||
}
|
||||
|
||||
config = await async_validate_condition_config(hass, config)
|
||||
test = await condition.async_from_config(hass, config)
|
||||
assert test is not None
|
||||
return test
|
||||
|
||||
|
||||
async def test_state_condition_attr_duration_not_met(
|
||||
hass: HomeAssistant, freezer: FrozenDateTimeFactory
|
||||
) -> None:
|
||||
"""Test attribute-based condition with duration: not met yet."""
|
||||
test = await _setup_attr_state_condition(
|
||||
hass,
|
||||
entity_ids="test.entity_1",
|
||||
states={True},
|
||||
condition_options={CONF_FOR: {"seconds": 10}},
|
||||
)
|
||||
|
||||
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Just set — duration not met
|
||||
assert test(hass) is False
|
||||
|
||||
freezer.tick(timedelta(seconds=5))
|
||||
assert test(hass) is False
|
||||
|
||||
|
||||
async def test_state_condition_attr_duration_met(
|
||||
hass: HomeAssistant, freezer: FrozenDateTimeFactory
|
||||
) -> None:
|
||||
"""Test attribute-based condition with duration: met after waiting."""
|
||||
test = await _setup_attr_state_condition(
|
||||
hass,
|
||||
entity_ids="test.entity_1",
|
||||
states={True},
|
||||
condition_options={CONF_FOR: {"seconds": 10}},
|
||||
)
|
||||
|
||||
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
freezer.tick(timedelta(seconds=11))
|
||||
assert test(hass) is True
|
||||
|
||||
|
||||
async def test_state_condition_attr_duration_reset_on_attr_change(
|
||||
hass: HomeAssistant, freezer: FrozenDateTimeFactory
|
||||
) -> None:
|
||||
"""Test attribute-based condition: timer resets when attribute changes.
|
||||
|
||||
This is the key difference from state-based duration: the tracked value
|
||||
is in an attribute, so state.last_changed does not capture it. The
|
||||
_valid_since tracking in async_setup handles this correctly.
|
||||
"""
|
||||
test = await _setup_attr_state_condition(
|
||||
hass,
|
||||
entity_ids="test.entity_1",
|
||||
states={True},
|
||||
condition_options={CONF_FOR: {"seconds": 10}},
|
||||
)
|
||||
|
||||
# Set attribute to True
|
||||
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# After 8s, change attribute to False (state stays the same)
|
||||
freezer.tick(timedelta(seconds=8))
|
||||
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": False})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Set attribute back to True
|
||||
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# 5s after re-set — not enough (timer was reset)
|
||||
freezer.tick(timedelta(seconds=5))
|
||||
assert test(hass) is False
|
||||
|
||||
# 6 more seconds (11 from re-set) — now met
|
||||
freezer.tick(timedelta(seconds=6))
|
||||
assert test(hass) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("behavior", "one_match_expected"),
|
||||
[(BEHAVIOR_ANY, True), (BEHAVIOR_ALL, False)],
|
||||
)
|
||||
async def test_state_condition_attr_duration_behavior(
|
||||
hass: HomeAssistant,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
behavior: str,
|
||||
one_match_expected: bool,
|
||||
) -> None:
|
||||
"""Test attribute-based condition with duration and behavior any/all."""
|
||||
test = await _setup_attr_state_condition(
|
||||
hass,
|
||||
entity_ids=["test.entity_1", "test.entity_2"],
|
||||
states={True},
|
||||
condition_options={ATTR_BEHAVIOR: behavior, CONF_FOR: {"seconds": 10}},
|
||||
)
|
||||
|
||||
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
|
||||
hass.states.async_set("test.entity_2", STATE_ON, {"test_attr": True})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Both matching but duration not met
|
||||
assert test(hass) is False
|
||||
|
||||
# Advance past duration — both matching long enough
|
||||
freezer.tick(timedelta(seconds=11))
|
||||
assert test(hass) is True
|
||||
|
||||
# Change entity_2 attribute — only one matching for duration
|
||||
hass.states.async_set("test.entity_2", STATE_ON, {"test_attr": False})
|
||||
await hass.async_block_till_done()
|
||||
assert test(hass) is one_match_expected
|
||||
|
||||
|
||||
@dataclass
|
||||
class _AttrInitStep:
|
||||
"""A state update step before the condition is created."""
|
||||
|
||||
state: str
|
||||
attrs: dict[str, Any] = field(default_factory=dict)
|
||||
delay_before: int = 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("steps", "duration", "initially_met"),
|
||||
[
|
||||
# Attribute set to valid 10s ago, no further changes → met (10 >= 5)
|
||||
(
|
||||
[_AttrInitStep(STATE_ON, {"test_attr": True})],
|
||||
10,
|
||||
True,
|
||||
),
|
||||
# Attribute set to valid 3s ago → not met (3 < 5)
|
||||
(
|
||||
[_AttrInitStep(STATE_ON, {"test_attr": True})],
|
||||
3,
|
||||
False,
|
||||
),
|
||||
# Attribute set to valid, then main state changes 2s later
|
||||
# (attribute stays valid). last_updated is bumped by the state change,
|
||||
# so the effective duration is only 2s from the second update → not met
|
||||
(
|
||||
[
|
||||
_AttrInitStep(STATE_ON, {"test_attr": True}),
|
||||
_AttrInitStep(STATE_OFF, {"test_attr": True}, delay_before=8),
|
||||
],
|
||||
2,
|
||||
False,
|
||||
),
|
||||
# Same as above but enough time after the state change → met
|
||||
(
|
||||
[
|
||||
_AttrInitStep(STATE_ON, {"test_attr": True}),
|
||||
_AttrInitStep(STATE_OFF, {"test_attr": True}, delay_before=2),
|
||||
],
|
||||
8,
|
||||
True,
|
||||
),
|
||||
# Attribute was invalid, then set to valid 4s ago → not met (4 < 5)
|
||||
(
|
||||
[
|
||||
_AttrInitStep(STATE_ON, {"test_attr": False}),
|
||||
_AttrInitStep(STATE_ON, {"test_attr": True}, delay_before=6),
|
||||
],
|
||||
4,
|
||||
False,
|
||||
),
|
||||
# Attribute was invalid, then set to valid 6s ago → met (6 >= 5)
|
||||
(
|
||||
[
|
||||
_AttrInitStep(STATE_ON, {"test_attr": False}),
|
||||
_AttrInitStep(STATE_ON, {"test_attr": True}, delay_before=4),
|
||||
],
|
||||
6,
|
||||
True,
|
||||
),
|
||||
# Attribute valid → invalid → valid 3s ago → not met (3 < 5)
|
||||
(
|
||||
[
|
||||
_AttrInitStep(STATE_ON, {"test_attr": True}),
|
||||
_AttrInitStep(STATE_ON, {"test_attr": False}, delay_before=5),
|
||||
_AttrInitStep(STATE_ON, {"test_attr": True}, delay_before=2),
|
||||
],
|
||||
3,
|
||||
False,
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"valid_long_enough",
|
||||
"valid_too_short",
|
||||
"state_change_bumps_last_updated_not_met",
|
||||
"state_change_bumps_last_updated_met",
|
||||
"invalid_then_valid_not_met",
|
||||
"invalid_then_valid_met",
|
||||
"valid_invalid_valid_not_met",
|
||||
],
|
||||
)
|
||||
async def test_state_condition_attr_duration_initial_state(
|
||||
hass: HomeAssistant,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
steps: list[_AttrInitStep],
|
||||
duration: int,
|
||||
initially_met: bool,
|
||||
) -> None:
|
||||
"""Test attribute-based condition initialization from existing state.
|
||||
|
||||
The condition uses last_updated (not last_changed) to determine how long
|
||||
an attribute-based condition has been true. This is conservative: when
|
||||
the main state changes but the tracked attribute stays the same,
|
||||
last_updated is bumped and the effective duration resets.
|
||||
"""
|
||||
for step in steps:
|
||||
freezer.tick(timedelta(seconds=step.delay_before))
|
||||
hass.states.async_set("test.entity_1", step.state, step.attrs)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
freezer.tick(timedelta(seconds=duration))
|
||||
test = await _setup_attr_state_condition(
|
||||
hass,
|
||||
entity_ids="test.entity_1",
|
||||
states={True},
|
||||
condition_options={CONF_FOR: {"seconds": 5}},
|
||||
)
|
||||
|
||||
assert test(hass) is initially_met
|
||||
|
||||
|
||||
async def _setup_attr_state_condition_with_target(
|
||||
hass: HomeAssistant,
|
||||
target: dict[str, Any],
|
||||
states: str | bool | set[str | bool],
|
||||
condition_options: dict[str, Any] | None = None,
|
||||
) -> condition.ConditionChecker:
|
||||
"""Set up an attribute-based state condition with a custom target."""
|
||||
condition_cls = make_entity_state_condition(
|
||||
_ATTR_DOMAIN_SPECS,
|
||||
states,
|
||||
support_duration=True,
|
||||
)
|
||||
|
||||
async def async_get_conditions(
|
||||
hass: HomeAssistant,
|
||||
) -> dict[str, type[Condition]]:
|
||||
return {"_": condition_cls}
|
||||
|
||||
mock_integration(hass, MockModule("test"))
|
||||
mock_platform(
|
||||
hass, "test.condition", Mock(async_get_conditions=async_get_conditions)
|
||||
)
|
||||
|
||||
config: dict[str, Any] = {
|
||||
CONF_CONDITION: "test",
|
||||
CONF_TARGET: target,
|
||||
CONF_OPTIONS: condition_options or {},
|
||||
}
|
||||
|
||||
config = await async_validate_condition_config(hass, config)
|
||||
test = await condition.async_from_config(hass, config)
|
||||
assert test is not None
|
||||
return test
|
||||
|
||||
|
||||
async def test_state_condition_attr_duration_entity_added_to_target(
|
||||
hass: HomeAssistant, freezer: FrozenDateTimeFactory
|
||||
) -> None:
|
||||
"""Test that _valid_since is primed when an entity is added to the tracked set.
|
||||
|
||||
When targeting by label, adding a label to an entity should make it
|
||||
tracked, and if it's already in a valid state, its duration should be
|
||||
primed from the state timestamps.
|
||||
"""
|
||||
label_reg = lr.async_get(hass)
|
||||
label = label_reg.async_create("Test Duration")
|
||||
|
||||
entity_reg = er.async_get(hass)
|
||||
entry = entity_reg.async_get_or_create(
|
||||
domain="test", platform="test", unique_id="duration_add"
|
||||
)
|
||||
|
||||
# Entity starts valid but without the label
|
||||
hass.states.async_set(entry.entity_id, STATE_ON, {"test_attr": True})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Create condition targeting the label
|
||||
test = await _setup_attr_state_condition_with_target(
|
||||
hass,
|
||||
target={ATTR_LABEL_ID: label.label_id},
|
||||
states={True},
|
||||
condition_options={CONF_FOR: {"seconds": 5}},
|
||||
)
|
||||
|
||||
# No entities have the label yet — condition has no entities to check,
|
||||
# behavior "any" with no matching entities returns False
|
||||
assert test(hass) is False
|
||||
|
||||
# Add the label to the entity — entity is already in valid state
|
||||
freezer.tick(timedelta(seconds=1))
|
||||
entity_reg.async_update_entity(entry.entity_id, labels={label.label_id})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Just added — duration not met yet
|
||||
assert test(hass) is False
|
||||
|
||||
# Wait past the duration from when entity was last_updated
|
||||
freezer.tick(timedelta(seconds=5))
|
||||
assert test(hass) is True
|
||||
|
||||
|
||||
async def test_state_condition_attr_duration_entity_removed_from_target(
|
||||
hass: HomeAssistant, freezer: FrozenDateTimeFactory
|
||||
) -> None:
|
||||
"""Test that _valid_since is evicted when an entity is removed from the tracked set."""
|
||||
label_reg = lr.async_get(hass)
|
||||
label = label_reg.async_create("Test Duration Remove")
|
||||
|
||||
entity_reg = er.async_get(hass)
|
||||
entry1 = entity_reg.async_get_or_create(
|
||||
domain="test", platform="test", unique_id="duration_remove_1"
|
||||
)
|
||||
entry2 = entity_reg.async_get_or_create(
|
||||
domain="test", platform="test", unique_id="duration_remove_2"
|
||||
)
|
||||
# Both entities start with the label
|
||||
entity_reg.async_update_entity(entry1.entity_id, labels={label.label_id})
|
||||
entity_reg.async_update_entity(entry2.entity_id, labels={label.label_id})
|
||||
|
||||
# Both entities in valid state
|
||||
hass.states.async_set(entry1.entity_id, STATE_ON, {"test_attr": True})
|
||||
hass.states.async_set(entry2.entity_id, STATE_ON, {"test_attr": True})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
test = await _setup_attr_state_condition_with_target(
|
||||
hass,
|
||||
target={ATTR_LABEL_ID: label.label_id},
|
||||
states={True},
|
||||
condition_options={
|
||||
ATTR_BEHAVIOR: BEHAVIOR_ALL,
|
||||
CONF_FOR: {"seconds": 5},
|
||||
},
|
||||
)
|
||||
|
||||
# Wait past duration — both valid
|
||||
freezer.tick(timedelta(seconds=6))
|
||||
assert test(hass) is True
|
||||
|
||||
# Remove label from entry2
|
||||
entity_reg.async_update_entity(entry2.entity_id, labels=set())
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Condition should still be True — only entry1 is tracked now, and it's valid
|
||||
assert test(hass) is True
|
||||
|
||||
# Now remove label from entry1 too
|
||||
entity_reg.async_update_entity(entry1.entity_id, labels=set())
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# No entities tracked — "all" with empty set is vacuously True
|
||||
assert test(hass) is True
|
||||
|
||||
# Change entry1 to invalid state and re-add its label
|
||||
hass.states.async_set(entry1.entity_id, STATE_ON, {"test_attr": False})
|
||||
await hass.async_block_till_done()
|
||||
entity_reg.async_update_entity(entry1.entity_id, labels={label.label_id})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# entry1 is now tracked again but invalid — "all" fails
|
||||
freezer.tick(timedelta(seconds=10))
|
||||
assert test(hass) is False
|
||||
|
||||
|
||||
async def test_state_condition_attr_duration_entity_added_then_state_changes(
|
||||
hass: HomeAssistant, freezer: FrozenDateTimeFactory
|
||||
) -> None:
|
||||
"""Test that a newly added entity's state changes are properly tracked."""
|
||||
label_reg = lr.async_get(hass)
|
||||
label = label_reg.async_create("Test Duration Track")
|
||||
|
||||
entity_reg = er.async_get(hass)
|
||||
entry = entity_reg.async_get_or_create(
|
||||
domain="test", platform="test", unique_id="duration_track"
|
||||
)
|
||||
|
||||
# Entity starts in invalid state
|
||||
hass.states.async_set(entry.entity_id, STATE_ON, {"test_attr": False})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Create condition targeting the label
|
||||
test = await _setup_attr_state_condition_with_target(
|
||||
hass,
|
||||
target={ATTR_LABEL_ID: label.label_id},
|
||||
states={True},
|
||||
condition_options={CONF_FOR: {"seconds": 5}},
|
||||
)
|
||||
|
||||
# Add the label — entity is invalid, so no priming
|
||||
entity_reg.async_update_entity(entry.entity_id, labels={label.label_id})
|
||||
await hass.async_block_till_done()
|
||||
assert test(hass) is False
|
||||
|
||||
# Now change to valid state
|
||||
freezer.tick(timedelta(seconds=1))
|
||||
hass.states.async_set(entry.entity_id, STATE_ON, {"test_attr": True})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Just became valid — not long enough
|
||||
freezer.tick(timedelta(seconds=3))
|
||||
assert test(hass) is False
|
||||
|
||||
# Now past the duration
|
||||
freezer.tick(timedelta(seconds=3))
|
||||
assert test(hass) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"compound_type",
|
||||
["and", "or", "not"],
|
||||
)
|
||||
async def test_compound_condition_forwards_async_unload(
|
||||
hass: HomeAssistant, compound_type: str
|
||||
) -> None:
|
||||
"""Test that and/or/not compound conditions forward async_unload to children."""
|
||||
config = {
|
||||
"condition": compound_type,
|
||||
"conditions": [
|
||||
{
|
||||
"condition": "state",
|
||||
"entity_id": "test.entity_1",
|
||||
"state": STATE_ON,
|
||||
},
|
||||
{
|
||||
"condition": "state",
|
||||
"entity_id": "test.entity_2",
|
||||
"state": STATE_ON,
|
||||
},
|
||||
],
|
||||
}
|
||||
config = cv.CONDITION_SCHEMA(config)
|
||||
config = await condition.async_validate_condition_config(hass, config)
|
||||
test = await condition.async_from_config(hass, config)
|
||||
|
||||
# The compound checker should hold child checkers
|
||||
assert hasattr(test, "_checks")
|
||||
assert len(test._checks) == 2
|
||||
|
||||
# Patch async_unload on children to verify forwarding
|
||||
child_unloads = [Mock() for _ in test._checks]
|
||||
for child, mock_unload in zip(test._checks, child_unloads, strict=True):
|
||||
child.async_unload = mock_unload
|
||||
|
||||
test.async_unload()
|
||||
|
||||
for mock_unload in child_unloads:
|
||||
mock_unload.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("outer_type", "inner_type"),
|
||||
[
|
||||
(outer, inner)
|
||||
for outer in ("and", "or", "not")
|
||||
for inner in ("and", "or", "not")
|
||||
],
|
||||
)
|
||||
async def test_nested_compound_condition_forwards_async_unload(
|
||||
hass: HomeAssistant, outer_type: str, inner_type: str
|
||||
) -> None:
|
||||
"""Test that nested compound conditions forward async_unload recursively."""
|
||||
config = {
|
||||
"condition": outer_type,
|
||||
"conditions": [
|
||||
{
|
||||
"condition": inner_type,
|
||||
"conditions": [
|
||||
{
|
||||
"condition": "state",
|
||||
"entity_id": "test.entity_1",
|
||||
"state": STATE_ON,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"condition": "state",
|
||||
"entity_id": "test.entity_2",
|
||||
"state": STATE_ON,
|
||||
},
|
||||
],
|
||||
}
|
||||
config = cv.CONDITION_SCHEMA(config)
|
||||
config = await condition.async_validate_condition_config(hass, config)
|
||||
test = await condition.async_from_config(hass, config)
|
||||
|
||||
# Outer compound with 2 children: an inner compound and a leaf
|
||||
assert len(test._checks) == 2
|
||||
inner_checker = test._checks[0]
|
||||
assert hasattr(inner_checker, "_checks")
|
||||
assert len(inner_checker._checks) == 1
|
||||
|
||||
# Patch the innermost leaf's async_unload
|
||||
innermost_unload = Mock()
|
||||
inner_checker._checks[0].async_unload = innermost_unload
|
||||
|
||||
leaf_unload = Mock()
|
||||
test._checks[1].async_unload = leaf_unload
|
||||
|
||||
test.async_unload()
|
||||
|
||||
innermost_unload.assert_called_once()
|
||||
leaf_unload.assert_called_once()
|
||||
|
||||
@@ -762,3 +762,121 @@ async def test_async_track_target_selector_state_change_event_filter(
|
||||
)
|
||||
|
||||
unsub()
|
||||
|
||||
|
||||
async def test_async_track_target_selector_state_change_event_on_entities_update(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test on_entities_update callback reports added and removed entities."""
|
||||
entity_updates: list[tuple[set[str], set[str]]] = []
|
||||
|
||||
@callback
|
||||
def state_change_callback(event: target.TargetStateChangedData) -> None:
|
||||
"""Handle state change events."""
|
||||
|
||||
@callback
|
||||
def on_entities_update(added: set[str], removed: set[str]) -> None:
|
||||
"""Track entity set changes."""
|
||||
entity_updates.append((added, removed))
|
||||
|
||||
config_entry = MockConfigEntry(domain="test")
|
||||
config_entry.add_to_hass(hass)
|
||||
|
||||
entity_reg = er.async_get(hass)
|
||||
label_reg = lr.async_get(hass)
|
||||
label = label_reg.async_create("Track Test")
|
||||
|
||||
entity_a = entity_reg.async_get_or_create(
|
||||
domain="light", platform="test", unique_id="track_a"
|
||||
)
|
||||
entity_b = entity_reg.async_get_or_create(
|
||||
domain="light", platform="test", unique_id="track_b"
|
||||
)
|
||||
|
||||
# entity_a starts with the label
|
||||
entity_reg.async_update_entity(entity_a.entity_id, labels={label.label_id})
|
||||
|
||||
hass.states.async_set(entity_a.entity_id, STATE_ON)
|
||||
hass.states.async_set(entity_b.entity_id, STATE_ON)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
unsub = target.async_track_target_selector_state_change_event(
|
||||
hass,
|
||||
{ATTR_LABEL_ID: label.label_id},
|
||||
state_change_callback,
|
||||
on_entities_update=on_entities_update,
|
||||
)
|
||||
|
||||
# Initial setup fires on_entities_update with all entities as "added"
|
||||
assert len(entity_updates) == 1
|
||||
assert entity_updates[-1] == ({entity_a.entity_id}, set())
|
||||
entity_updates.clear()
|
||||
|
||||
# Add label to entity_b → added
|
||||
entity_reg.async_update_entity(entity_b.entity_id, labels={label.label_id})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(entity_updates) == 1
|
||||
assert entity_updates[-1] == ({entity_b.entity_id}, set())
|
||||
entity_updates.clear()
|
||||
|
||||
# Remove label from entity_a → removed
|
||||
entity_reg.async_update_entity(entity_a.entity_id, labels=set())
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(entity_updates) == 1
|
||||
assert entity_updates[-1] == (set(), {entity_a.entity_id})
|
||||
entity_updates.clear()
|
||||
|
||||
# Remove label from entity_b → removed
|
||||
entity_reg.async_update_entity(entity_b.entity_id, labels=set())
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(entity_updates) == 1
|
||||
assert entity_updates[-1] == (set(), {entity_b.entity_id})
|
||||
entity_updates.clear()
|
||||
|
||||
# Re-add both labels at once — entity_a first, then entity_b
|
||||
entity_reg.async_update_entity(entity_a.entity_id, labels={label.label_id})
|
||||
await hass.async_block_till_done()
|
||||
entity_reg.async_update_entity(entity_b.entity_id, labels={label.label_id})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(entity_updates) == 2
|
||||
assert entity_updates[0] == ({entity_a.entity_id}, set())
|
||||
assert entity_updates[1] == ({entity_b.entity_id}, set())
|
||||
entity_updates.clear()
|
||||
|
||||
# After unsubscribing, no more callbacks
|
||||
unsub()
|
||||
entity_reg.async_update_entity(entity_a.entity_id, labels=set())
|
||||
await hass.async_block_till_done()
|
||||
assert len(entity_updates) == 0
|
||||
|
||||
|
||||
async def test_async_track_target_selector_no_on_entities_update(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test that on_entities_update is optional and defaults to no callback."""
|
||||
events: list[target.TargetStateChangedData] = []
|
||||
|
||||
@callback
|
||||
def state_change_callback(event: target.TargetStateChangedData) -> None:
|
||||
events.append(event)
|
||||
|
||||
entity_id = "light.test_no_callback"
|
||||
hass.states.async_set(entity_id, STATE_ON)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# No on_entities_update — should work without errors
|
||||
unsub = target.async_track_target_selector_state_change_event(
|
||||
hass,
|
||||
{ATTR_ENTITY_ID: entity_id},
|
||||
state_change_callback,
|
||||
)
|
||||
|
||||
hass.states.async_set(entity_id, STATE_OFF)
|
||||
await hass.async_block_till_done()
|
||||
assert len(events) == 1
|
||||
|
||||
unsub()
|
||||
|
||||
Reference in New Issue
Block a user