Compare commits

...

14 Commits

Author SHA1 Message Date
Erik 7ef88f550f Adjust compound conditions 2026-04-23 17:17:36 +02:00
Erik c283826369 Update websocket_api.handle_test_condition 2026-04-23 17:15:39 +02:00
Erik 13565f5c94 Adjust EntityConditionBase 2026-04-23 17:15:39 +02:00
Erik c208d68292 Update compound conditions 2026-04-23 17:15:39 +02:00
Erik 8bdb5e7a3c Add duration support to cover conditions 2026-04-23 17:15:39 +02:00
Erik d8d8bb23a5 Add state tracking to EntityConditionBase 2026-04-23 17:15:39 +02:00
Erik 5a10e105a8 Migrate compound conditions to ConditionChecker 2026-04-23 17:15:39 +02:00
Erik 65a68c138c Reintroduce ConditionCheckParams 2026-04-23 17:14:48 +02:00
Erik 8237c4db12 Adjust 2026-04-23 16:52:25 +02:00
Erik d380ff61a5 Address review comments 2026-04-23 16:42:16 +02:00
Erik 0473407d38 Add ConditionChecker.async_on_unload 2026-04-23 16:26:03 +02:00
Erik f7aecb654b Log exceptions in cleanup 2026-04-23 13:48:00 +02:00
Erik f7a91721dc Adjust according to feedback 2026-04-23 07:50:05 +02:00
Erik d200e547e1 Refactor condition API 2026-04-22 13:28:44 +02:00
19 changed files with 1135 additions and 190 deletions
+6 -1
View File
@@ -4,7 +4,11 @@ from collections.abc import Mapping
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.condition import Condition, EntityConditionBase
from homeassistant.helpers.condition import (
ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL_FOR,
Condition,
EntityConditionBase,
)
from .const import ATTR_IS_CLOSED, DOMAIN, CoverDeviceClass
from .models import CoverDomainSpec
@@ -14,6 +18,7 @@ class CoverConditionBase(EntityConditionBase):
"""Base condition for cover state checks."""
_domain_specs: Mapping[str, CoverDomainSpec]
_schema = ENTITY_STATE_CONDITION_SCHEMA_ANY_ALL_FOR
def is_valid_state(self, entity_state: State) -> bool:
"""Check if the state matches the expected cover state."""
@@ -8,6 +8,11 @@
options:
- all
- any
for:
required: true
default: 00:00:00
selector:
duration:
awning_is_closed:
fields: *condition_common_fields
@@ -1,6 +1,7 @@
{
"common": {
"condition_behavior_name": "Condition passes if",
"condition_for_name": "For at least",
"trigger_behavior_name": "Trigger when",
"trigger_for_name": "For at least"
},
@@ -10,6 +11,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Awning is closed"
@@ -19,6 +23,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Awning is open"
@@ -28,6 +35,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Blind is closed"
@@ -37,6 +47,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Blind is open"
@@ -46,6 +59,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Curtain is closed"
@@ -55,6 +71,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Curtain is open"
@@ -64,6 +83,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Shade is closed"
@@ -73,6 +95,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Shade is open"
@@ -82,6 +107,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Shutter is closed"
@@ -91,6 +119,9 @@
"fields": {
"behavior": {
"name": "[%key:component::cover::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::cover::common::condition_for_name%]"
}
},
"name": "Shutter is open"
@@ -7,11 +7,10 @@ from typing import Any, Protocol
import voluptuous as vol
from homeassistant.const import CONF_DOMAIN, CONF_OPTIONS
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.condition import (
Condition,
ConditionChecker,
ConditionCheckerType,
ConditionConfig,
)
@@ -54,6 +53,7 @@ class DeviceCondition(Condition):
"""Device condition."""
_config: ConfigType
_platform_checker: ConditionCheckerType
@classmethod
async def async_validate_complete_config(
@@ -87,20 +87,20 @@ class DeviceCondition(Condition):
assert config.options is not None
self._config = config.options
async def async_get_checker(self) -> ConditionChecker:
async def async_setup(self) -> None:
"""Test a device condition."""
platform = await async_get_device_automation_platform(
self._hass, self._config[CONF_DOMAIN], DeviceAutomationType.CONDITION
)
platform_checker = platform.async_condition_from_config(
self._platform_checker = platform.async_condition_from_config(
self._hass, self._config
)
def checker(variables: TemplateVarsType = None, **kwargs: Any) -> bool:
result = platform_checker(self._hass, variables)
return result is not False
return checker
@callback
def _async_check(self, variables: TemplateVarsType = None, **kwargs: Any) -> bool:
"""Check the condition."""
result = self._platform_checker(self._hass, variables)
return result is not False
CONDITIONS: dict[str, type[Condition]] = {
@@ -8,6 +8,11 @@
options:
- all
- any
for:
required: true
default: 00:00:00
selector:
duration:
is_closed:
fields: *condition_common_fields
@@ -1,6 +1,7 @@
{
"common": {
"condition_behavior_name": "Condition passes if",
"condition_for_name": "For at least",
"trigger_behavior_name": "Trigger when",
"trigger_for_name": "For at least"
},
@@ -10,6 +11,9 @@
"fields": {
"behavior": {
"name": "[%key:component::door::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::door::common::condition_for_name%]"
}
},
"name": "Door is closed"
@@ -19,6 +23,9 @@
"fields": {
"behavior": {
"name": "[%key:component::door::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::door::common::condition_for_name%]"
}
},
"name": "Door is open"
@@ -8,6 +8,11 @@
options:
- all
- any
for:
required: true
default: 00:00:00
selector:
duration:
is_closed:
fields: *condition_common_fields
@@ -1,6 +1,7 @@
{
"common": {
"condition_behavior_name": "Condition passes if",
"condition_for_name": "For at least",
"trigger_behavior_name": "Trigger when",
"trigger_for_name": "For at least"
},
@@ -10,6 +11,9 @@
"fields": {
"behavior": {
"name": "[%key:component::garage_door::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::garage_door::common::condition_for_name%]"
}
},
"name": "Garage door is closed"
@@ -19,6 +23,9 @@
"fields": {
"behavior": {
"name": "[%key:component::garage_door::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::garage_door::common::condition_for_name%]"
}
},
"name": "Garage door is open"
@@ -8,6 +8,11 @@
options:
- all
- any
for:
required: true
default: 00:00:00
selector:
duration:
is_closed:
fields: *condition_common_fields
@@ -1,6 +1,7 @@
{
"common": {
"condition_behavior_name": "Condition passes if",
"condition_for_name": "For at least",
"trigger_behavior_name": "Trigger when",
"trigger_for_name": "For at least"
},
@@ -10,6 +11,9 @@
"fields": {
"behavior": {
"name": "[%key:component::gate::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::gate::common::condition_for_name%]"
}
},
"name": "Gate is closed"
@@ -19,6 +23,9 @@
"fields": {
"behavior": {
"name": "[%key:component::gate::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::gate::common::condition_for_name%]"
}
},
"name": "Gate is open"
+15 -14
View File
@@ -8,12 +8,11 @@ from typing import Any, Unpack, cast
import voluptuous as vol
from homeassistant.const import CONF_OPTIONS, SUN_EVENT_SUNRISE, SUN_EVENT_SUNSET
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import (
Condition,
ConditionChecker,
ConditionCheckParams,
ConditionConfig,
condition_trace_set_result,
@@ -151,19 +150,21 @@ class SunCondition(Condition):
super().__init__(hass, config)
assert config.options is not None
self._options = config.options
self._before = self._options.get("before")
self._after = self._options.get("after")
self._before_offset = self._options.get("before_offset")
self._after_offset = self._options.get("after_offset")
async def async_get_checker(self) -> ConditionChecker:
"""Wrap action method with sun based condition."""
before = self._options.get("before")
after = self._options.get("after")
before_offset = self._options.get("before_offset")
after_offset = self._options.get("after_offset")
def sun_if(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Validate time based if-condition."""
return sun(self._hass, before, after, before_offset, after_offset)
return sun_if
@callback
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Check the condition."""
return sun(
self._hass,
self._before,
self._after,
self._before_offset,
self._after_offset,
)
CONDITIONS: dict[str, type[Condition]] = {
@@ -1024,10 +1024,11 @@ async def handle_test_condition(
# Do static + dynamic validation of the condition
config = await async_validate_condition_config(hass, msg["condition"])
# Test the condition
check_condition = await async_condition_from_config(hass, config)
condition = await async_condition_from_config(hass, config)
connection.send_result(
msg["id"], {"result": check_condition(hass, msg.get("variables"))}
msg["id"], {"result": condition.async_check(variables=msg.get("variables"))}
)
condition.async_unload()
@decorators.websocket_command(
@@ -8,6 +8,11 @@
options:
- all
- any
for:
required: true
default: 00:00:00
selector:
duration:
is_closed:
fields: *condition_common_fields
@@ -1,6 +1,7 @@
{
"common": {
"condition_behavior_name": "Condition passes if",
"condition_for_name": "For at least",
"trigger_behavior_name": "Trigger when",
"trigger_for_name": "For at least"
},
@@ -10,6 +11,9 @@
"fields": {
"behavior": {
"name": "[%key:component::window::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::window::common::condition_for_name%]"
}
},
"name": "Window is closed"
@@ -19,6 +23,9 @@
"fields": {
"behavior": {
"name": "[%key:component::window::common::condition_behavior_name%]"
},
"for": {
"name": "[%key:component::window::common::condition_for_name%]"
}
},
"name": "Window is open"
+29 -34
View File
@@ -16,13 +16,12 @@ from homeassistant.const import (
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import HomeAssistant, State
from homeassistant.core import HomeAssistant, State, callback
from homeassistant.exceptions import ConditionErrorContainer, ConditionErrorMessage
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.automation import move_top_level_schema_fields_to_options
from homeassistant.helpers.condition import (
Condition,
ConditionChecker,
ConditionCheckParams,
ConditionConfig,
)
@@ -117,44 +116,40 @@ class ZoneCondition(Condition):
super().__init__(hass, config)
assert config.options is not None
self._options = config.options
self._entity_ids = self._options.get(CONF_ENTITY_ID, [])
self._zone_entity_ids = self._options.get(CONF_ZONE, [])
async def async_get_checker(self) -> ConditionChecker:
"""Wrap action method with zone based condition."""
entity_ids = self._options.get(CONF_ENTITY_ID, [])
zone_entity_ids = self._options.get(CONF_ZONE, [])
@callback
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test if condition."""
errors = []
def if_in_zone(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test if condition."""
errors = []
all_ok = True
for entity_id in entity_ids:
entity_ok = False
for zone_entity_id in zone_entity_ids:
try:
if zone(self._hass, zone_entity_id, entity_id):
entity_ok = True
except ConditionErrorMessage as ex:
errors.append(
ConditionErrorMessage(
"zone",
(
f"error matching {entity_id} with {zone_entity_id}:"
f" {ex.message}"
),
)
all_ok = True
for entity_id in self._entity_ids:
entity_ok = False
for zone_entity_id in self._zone_entity_ids:
try:
if zone(self._hass, zone_entity_id, entity_id):
entity_ok = True
except ConditionErrorMessage as ex:
errors.append(
ConditionErrorMessage(
"zone",
(
f"error matching {entity_id} with {zone_entity_id}:"
f" {ex.message}"
),
)
)
if not entity_ok:
all_ok = False
if not entity_ok:
all_ok = False
# Raise the errors only if no definitive result was found
if errors and not all_ok:
raise ConditionErrorContainer("zone", errors=errors)
# Raise the errors only if no definitive result was found
if errors and not all_ok:
raise ConditionErrorContainer("zone", errors=errors)
return all_ok
return if_in_zone
return all_ok
CONDITIONS: dict[str, type[Condition]] = {
+284 -116
View File
@@ -93,7 +93,12 @@ from .selector import (
NumericThresholdType,
TargetSelector,
)
from .target import TargetSelection, async_extract_referenced_entity_ids
from .target import (
TargetSelection,
TargetStateChangedData,
async_extract_referenced_entity_ids,
async_track_target_selector_state_change_event,
)
from .template import Template, render_complex
from .trace import (
TraceElement,
@@ -284,10 +289,97 @@ _CONDITION_SCHEMA = _CONDITION_BASE_SCHEMA.extend(
)
class Condition(abc.ABC):
"""Condition class."""
class ConditionChecker(abc.ABC):
"""Base class for condition checkers."""
_hass: HomeAssistant
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize condition checker."""
self._hass = hass
self._on_unload: list[Callable[[], None]] = []
def __call__(
self, _: HomeAssistant, variables: TemplateVarsType | None = None
) -> bool | None:
"""Check the condition."""
return self.async_check(variables=variables)
def __del__(self) -> None:
"""Clean up when the checker is deleted."""
try:
self.async_unload()
except Exception:
_LOGGER.exception("Error while unloading condition checker")
def async_check(
self, *, variables: TemplateVarsType | None = None, **kwargs: Any
) -> bool | None:
"""Check the condition."""
with trace_condition(variables):
result = self._async_check(variables=variables, **kwargs)
condition_trace_update_result(result=result)
return result
async def async_setup(self) -> None:
"""Set up the condition checker.
Intended to be overridden in derived classes that need to do async setup.
"""
@callback
def async_on_unload(self, func: Callable[[], None]) -> None:
"""Add a function to call when config entry is unloaded."""
self._on_unload.append(func)
def async_unload(self) -> None:
"""Clean up any resources held by the checker."""
for cb in self._on_unload:
cb()
self._on_unload.clear()
@abc.abstractmethod
@callback
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool | None:
"""Check the condition."""
class LegacyConditionChecker(ConditionChecker):
"""Condition checker wrapping a legacy condition factory function."""
def __init__(self, hass: HomeAssistant, checker: ConditionCheckerType) -> None:
"""Initialize condition checker."""
super().__init__(hass)
self._checker = checker
@callback
def _async_check(self, variables: TemplateVarsType = None, **kwargs: Any) -> bool:
return self._checker(self._hass, variables)
class DisabledConditionChecker(ConditionChecker):
"""Condition checker for disabled conditions."""
@callback
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> None:
return None
class CompoundConditionChecker(ConditionChecker):
"""Base class for compound condition checkers (and/or/not)."""
def __init__(self, hass: HomeAssistant, checks: list[ConditionChecker]) -> None:
"""Initialize condition checker."""
super().__init__(hass)
self._checks = checks
def async_unload(self) -> None:
"""Clean up child conditions."""
for check in self._checks:
check.async_unload()
super().async_unload()
class Condition(ConditionChecker):
"""Condition class."""
@classmethod
async def async_validate_complete_config(
@@ -323,11 +415,7 @@ class Condition(abc.ABC):
def __init__(self, hass: HomeAssistant, config: ConditionConfig) -> None:
"""Initialize condition."""
self._hass = hass
@abc.abstractmethod
async def async_get_checker(self) -> ConditionChecker:
"""Get the condition checker."""
super().__init__(hass)
ATTR_BEHAVIOR: Final = "behavior"
@@ -376,14 +464,99 @@ class EntityConditionBase(Condition):
if TYPE_CHECKING:
assert config.target
assert config.options
self._target_config = config.target
self._target_selection = TargetSelection(config.target)
self._behavior = config.options[ATTR_BEHAVIOR]
self._duration: timedelta | None = config.options.get(CONF_FOR)
if self._behavior == BEHAVIOR_ANY:
self._matcher = self._check_any_match_state
elif self._behavior == BEHAVIOR_ALL:
self._matcher = self._check_all_match_state
self._valid_since: dict[str, datetime] = {}
def entity_filter(self, entities: set[str]) -> set[str]:
"""Filter entities matching any of the domain specs."""
return filter_by_domain_specs(self._hass, self._domain_specs, entities)
@property
def _needs_duration_tracking(self) -> bool:
"""Whether this condition needs active state change tracking for duration.
Conditions that are true for a single main state value can use
state.last_changed directly. Conditions that track attributes or
match multiple states need active tracking because last_changed
does not capture those transitions.
"""
return True
def _prime_valid_since(self, entity_id: str) -> None:
"""Prime _valid_since for an entity already in a valid state.
For state-based conditions (value_source is None), last_changed
accurately reflects when the state changed to the current value.
For attribute-based conditions, last_changed only tracks main state
changes, so we use last_updated which is bumped on any update
(state or attributes). This is conservative - the tracked attribute
may have held its value longer - but it's the best we can do
to avoid false positives.
"""
if (
(_state := self._hass.states.get(entity_id)) is not None
and _state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
and self.is_valid_state(_state)
):
domain_spec = self._domain_specs[_state.domain]
if domain_spec.value_source is None:
self._valid_since[entity_id] = _state.last_changed
else:
self._valid_since[entity_id] = _state.last_updated
@override
async def async_setup(self) -> None:
"""Set up state tracking for duration-based conditions."""
await super().async_setup()
if not self._duration or not self._needs_duration_tracking:
return
@callback
def _state_change_listener(
data: TargetStateChangedData,
) -> None:
"""Track when entities enter or leave a valid state."""
event = data.state_change_event
entity_id = event.data["entity_id"]
to_state = event.data["new_state"]
if (
to_state is not None
and to_state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
and self.is_valid_state(to_state)
):
if entity_id not in self._valid_since:
# Prime from the state object — this handles both
# genuine transitions and newly tracked entities
self._prime_valid_since(entity_id)
else:
self._valid_since.pop(entity_id, None)
@callback
def _on_entities_update(added: set[str], removed: set[str]) -> None:
"""Handle changes to the tracked entity set."""
for entity_id in added:
self._prime_valid_since(entity_id)
for entity_id in removed:
self._valid_since.pop(entity_id, None)
self.async_on_unload(
async_track_target_selector_state_change_event(
self._hass,
self._target_config,
_state_change_listener,
self.entity_filter,
_on_entities_update,
)
)
def _get_tracked_value(self, entity_state: State) -> Any:
"""Get the tracked value from a state based on the DomainSpec."""
domain_spec = self._domain_specs[entity_state.domain]
@@ -395,56 +568,58 @@ class EntityConditionBase(Condition):
def is_valid_state(self, entity_state: State) -> bool:
"""Check if the state matches the expected state(s)."""
@override
async def async_get_checker(self) -> ConditionChecker:
"""Get the condition checker."""
def check_any_match_state(states: list[State]) -> bool:
"""Test if any entity matches the state."""
if not self._duration:
# Skip duration check if duration is not specified or 0
return any(self.is_valid_state(state) for state in states)
duration = dt_util.utcnow() - self._duration
def _check_any_match_state(self, states: list[State]) -> bool:
"""Test if any entity matches the state."""
if not self._duration:
# Skip duration check if duration is not specified or 0
return any(self.is_valid_state(state) for state in states)
cutoff = dt_util.utcnow() - self._duration
if not self._needs_duration_tracking:
return any(
self.is_valid_state(state) and duration > state.last_changed
self.is_valid_state(state) and state.last_changed <= cutoff
for state in states
)
return any(
self.is_valid_state(state)
and (valid_since := self._valid_since.get(state.entity_id)) is not None
and valid_since <= cutoff
for state in states
)
def check_all_match_state(states: list[State]) -> bool:
"""Test if all entities match the state."""
if not self._duration:
# Skip duration check if duration is not specified or 0
return all(self.is_valid_state(state) for state in states)
duration = dt_util.utcnow() - self._duration
def _check_all_match_state(self, states: list[State]) -> bool:
"""Test if all entities match the state."""
if not self._duration:
# Skip duration check if duration is not specified or 0
return all(self.is_valid_state(state) for state in states)
cutoff = dt_util.utcnow() - self._duration
if not self._needs_duration_tracking:
return all(
self.is_valid_state(state) and duration > state.last_changed
self.is_valid_state(state) and state.last_changed <= cutoff
for state in states
)
return all(
self.is_valid_state(state)
and (valid_since := self._valid_since.get(state.entity_id)) is not None
and valid_since <= cutoff
for state in states
)
matcher: Callable[[list[State]], bool]
if self._behavior == BEHAVIOR_ANY:
matcher = check_any_match_state
elif self._behavior == BEHAVIOR_ALL:
matcher = check_all_match_state
def test_state(**kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test state condition."""
targeted_entities = async_extract_referenced_entity_ids(
self._hass, self._target_selection, expand_group=False
)
referenced_entity_ids = targeted_entities.referenced.union(
targeted_entities.indirectly_referenced
)
filtered_entity_ids = self.entity_filter(referenced_entity_ids)
entity_states = [
_state
for entity_id in filtered_entity_ids
if (_state := self._hass.states.get(entity_id))
and _state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
]
return matcher(entity_states)
return test_state
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test state condition."""
targeted_entities = async_extract_referenced_entity_ids(
self._hass, self._target_selection, expand_group=False
)
referenced_entity_ids = targeted_entities.referenced.union(
targeted_entities.indirectly_referenced
)
filtered_entity_ids = self.entity_filter(referenced_entity_ids)
entity_states = [
_state
for entity_id in filtered_entity_ids
if (_state := self._hass.states.get(entity_id))
and _state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
]
return self._matcher(entity_states)
class EntityStateConditionBase(EntityConditionBase):
@@ -452,6 +627,15 @@ class EntityStateConditionBase(EntityConditionBase):
_states: set[str | bool]
@property
def _needs_duration_tracking(self) -> bool:
"""Single-state conditions with no attribute tracking can use last_changed."""
if len(self._states) != 1:
return True
return any(
spec.value_source is not None for spec in self._domain_specs.values()
)
def is_valid_state(self, entity_state: State) -> bool:
"""Check if the state matches the expected state(s)."""
return self._get_tracked_value(entity_state) in self._states
@@ -739,13 +923,6 @@ class ConditionCheckParams(TypedDict, total=False):
variables: TemplateVarsType
class ConditionChecker(Protocol):
"""Protocol for condition checker callable with typed kwargs."""
def __call__(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Check the condition."""
type ConditionCheckerType = Callable[[HomeAssistant, TemplateVarsType], bool]
type ConditionCheckerTypeOptional = Callable[
[HomeAssistant, TemplateVarsType], bool | None
@@ -869,20 +1046,10 @@ async def _async_get_condition_platform(
return platform, platform_module
async def _async_get_checker(condition: Condition) -> ConditionCheckerType:
new_checker = await condition.async_get_checker()
@trace_condition_function
def checker(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
return new_checker(variables=variables)
return checker
async def async_from_config(
hass: HomeAssistant,
config: ConfigType,
) -> ConditionCheckerTypeOptional:
) -> ConditionChecker:
"""Turn a condition configuration into a method.
Should be run on the event loop.
@@ -898,15 +1065,7 @@ async def async_from_config(
f"Error rendering condition enabled template: {err}"
) from err
if not enabled:
@trace_condition_function
def disabled_condition(
hass: HomeAssistant, variables: TemplateVarsType = None
) -> bool | None:
"""Condition not enabled, will act as if it didn't exist."""
return None
return disabled_condition
return DisabledConditionChecker(hass)
condition_key: str = config[CONF_CONDITION]
factory: Any = None
@@ -925,7 +1084,8 @@ async def async_from_config(
target=config.get(CONF_TARGET),
),
)
return await _async_get_checker(condition)
await condition.async_setup()
return condition
for fmt in (ASYNC_FROM_CONFIG_FORMAT, FROM_CONFIG_FORMAT):
factory = getattr(sys.modules[__name__], fmt.format(condition_key), None)
@@ -939,30 +1099,39 @@ async def async_from_config(
check_factory = check_factory.func
if inspect.iscoroutinefunction(check_factory):
return cast(ConditionCheckerType, await factory(hass, config))
return cast(ConditionCheckerType, factory(config))
checker = await factory(hass, config)
else:
checker = factory(config)
if isinstance(checker, ConditionChecker):
return checker
return LegacyConditionChecker(hass, cast(ConditionCheckerType, checker))
async def async_and_from_config(
hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
) -> ConditionChecker:
"""Create multi condition matcher using 'AND'."""
checks = [await async_from_config(hass, entry) for entry in config["conditions"]]
return AndConditionChecker(hass, checks)
@trace_condition_function
def if_and_condition(
hass: HomeAssistant, variables: TemplateVarsType = None
) -> bool:
class AndConditionChecker(CompoundConditionChecker):
"""Condition checker for 'and' compound conditions."""
@callback
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test and condition."""
errors = []
for index, check in enumerate(checks):
for index, check in enumerate(self._checks):
try:
with trace_path(["conditions", str(index)]):
if check(hass, variables) is False:
if check(self._hass, **kwargs) is False:
return False
except ConditionError as ex:
errors.append(
ConditionErrorIndex("and", index=index, total=len(checks), error=ex)
ConditionErrorIndex(
"and", index=index, total=len(self._checks), error=ex
)
)
# Raise the errors if no check was false
@@ -971,29 +1140,32 @@ async def async_and_from_config(
return True
return if_and_condition
async def async_or_from_config(
hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
) -> ConditionChecker:
"""Create multi condition matcher using 'OR'."""
checks = [await async_from_config(hass, entry) for entry in config["conditions"]]
return OrConditionChecker(hass, checks)
@trace_condition_function
def if_or_condition(
hass: HomeAssistant, variables: TemplateVarsType = None
) -> bool:
class OrConditionChecker(CompoundConditionChecker):
"""Condition checker for 'or' compound conditions."""
@callback
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test or condition."""
errors = []
for index, check in enumerate(checks):
for index, check in enumerate(self._checks):
try:
with trace_path(["conditions", str(index)]):
if check(hass, variables) is True:
if check(self._hass, **kwargs) is True:
return True
except ConditionError as ex:
errors.append(
ConditionErrorIndex("or", index=index, total=len(checks), error=ex)
ConditionErrorIndex(
"or", index=index, total=len(self._checks), error=ex
)
)
# Raise the errors if no check was true
@@ -1002,29 +1174,32 @@ async def async_or_from_config(
return False
return if_or_condition
async def async_not_from_config(
hass: HomeAssistant, config: ConfigType
) -> ConditionCheckerType:
) -> ConditionChecker:
"""Create multi condition matcher using 'NOT'."""
checks = [await async_from_config(hass, entry) for entry in config["conditions"]]
return NotConditionChecker(hass, checks)
@trace_condition_function
def if_not_condition(
hass: HomeAssistant, variables: TemplateVarsType = None
) -> bool:
class NotConditionChecker(CompoundConditionChecker):
"""Condition checker for 'not' compound conditions."""
@callback
def _async_check(self, **kwargs: Unpack[ConditionCheckParams]) -> bool:
"""Test not condition."""
errors = []
for index, check in enumerate(checks):
for index, check in enumerate(self._checks):
try:
with trace_path(["conditions", str(index)]):
if check(hass, variables):
if check(self._hass, **kwargs):
return False
except ConditionError as ex:
errors.append(
ConditionErrorIndex("not", index=index, total=len(checks), error=ex)
ConditionErrorIndex(
"not", index=index, total=len(self._checks), error=ex
)
)
# Raise the errors if no check was true
@@ -1033,8 +1208,6 @@ async def async_not_from_config(
return True
return if_not_condition
def numeric_state(
hass: HomeAssistant,
@@ -1191,7 +1364,6 @@ def async_numeric_state_from_config(config: ConfigType) -> ConditionCheckerType:
above = config.get(CONF_ABOVE)
value_template = config.get(CONF_VALUE_TEMPLATE)
@trace_condition_function
def if_numeric_state(
hass: HomeAssistant, variables: TemplateVarsType = None
) -> bool:
@@ -1310,7 +1482,6 @@ def state_from_config(config: ConfigType) -> ConditionCheckerType:
if not isinstance(req_states, list):
req_states = [req_states]
@trace_condition_function
def if_state(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Test if condition."""
errors = []
@@ -1372,7 +1543,6 @@ def async_template_from_config(config: ConfigType) -> ConditionCheckerType:
"""Wrap action method with state based condition."""
value_template = cast(Template, config.get(CONF_VALUE_TEMPLATE))
@trace_condition_function
def template_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Validate template based if-condition."""
return async_template(hass, value_template, variables)
@@ -1485,7 +1655,6 @@ def time_from_config(config: ConfigType) -> ConditionCheckerType:
after = config.get(CONF_AFTER)
weekday = config.get(CONF_WEEKDAY)
@trace_condition_function
def time_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Validate time based if-condition."""
return time(hass, before, after, weekday)
@@ -1499,7 +1668,6 @@ async def async_trigger_from_config(
"""Test a trigger condition."""
trigger_id = config[CONF_ID]
@trace_condition_function
def trigger_if(hass: HomeAssistant, variables: TemplateVarsType = None) -> bool:
"""Validate trigger based if-condition."""
return (
+15 -1
View File
@@ -345,14 +345,25 @@ class TargetStateChangeTracker(TargetEntityChangeTracker):
target_selection: TargetSelection,
action: Callable[[TargetStateChangedData], Any],
entity_filter: Callable[[set[str]], set[str]],
on_entities_update: Callable[[set[str], set[str]], None] | None = None,
) -> None:
"""Initialize the state change tracker."""
super().__init__(hass, target_selection, entity_filter)
self._action = action
self._on_entities_update = on_entities_update
self._state_change_unsub: CALLBACK_TYPE | None = None
self._tracked_entities: set[str] = set()
def _handle_entities_update(self, tracked_entities: set[str]) -> None:
"""Handle the tracked entities."""
previous_entities = self._tracked_entities
self._tracked_entities = tracked_entities
if self._on_entities_update is not None:
added = tracked_entities - previous_entities
removed = previous_entities - tracked_entities
if added or removed:
self._on_entities_update(added, removed)
@callback
def state_change_listener(event: Event[EventStateChangedData]) -> None:
@@ -380,6 +391,7 @@ def async_track_target_selector_state_change_event(
target_selector_config: ConfigType,
action: Callable[[TargetStateChangedData], Any],
entity_filter: Callable[[set[str]], set[str]] = lambda x: x,
on_entities_update: Callable[[set[str], set[str]], None] | None = None,
) -> CALLBACK_TYPE:
"""Track state changes for entities referenced directly or indirectly in a target selector."""
target_selection = TargetSelection(target_selector_config)
@@ -387,5 +399,7 @@ def async_track_target_selector_state_change_event(
raise HomeAssistantError(
f"Target selector {target_selector_config} does not have any selectors defined"
)
tracker = TargetStateChangeTracker(hass, target_selection, action, entity_filter)
tracker = TargetStateChangeTracker(
hass, target_selection, action, entity_filter, on_entities_update
)
return tracker.async_setup()
+572 -13
View File
@@ -2,6 +2,7 @@
from collections.abc import Mapping
from contextlib import AbstractContextManager, nullcontext as does_not_raise
from dataclasses import dataclass, field
from datetime import timedelta
import io
from typing import Any
@@ -22,6 +23,7 @@ from homeassistant.components.sun import DOMAIN as SUN_DOMAIN
from homeassistant.components.system_health import DOMAIN as SYSTEM_HEALTH_DOMAIN
from homeassistant.const import (
ATTR_DEVICE_CLASS,
ATTR_LABEL_ID,
ATTR_UNIT_OF_MEASUREMENT,
CONF_CONDITION,
CONF_DEVICE_ID,
@@ -42,6 +44,7 @@ from homeassistant.helpers import (
condition,
config_validation as cv,
entity_registry as er,
label_registry as lr,
trace,
)
from homeassistant.helpers.automation import (
@@ -54,7 +57,6 @@ from homeassistant.helpers.condition import (
BEHAVIOR_ANY,
CONDITIONS,
Condition,
ConditionChecker,
EntityNumericalConditionWithUnitBase,
_async_get_condition_platform,
async_validate_condition_config,
@@ -63,7 +65,12 @@ from homeassistant.helpers.condition import (
make_entity_state_condition,
)
from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import UNDEFINED, ConfigType, UndefinedType
from homeassistant.helpers.typing import (
UNDEFINED,
ConfigType,
TemplateVarsType,
UndefinedType,
)
from homeassistant.loader import Integration, async_get_integration
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
@@ -2150,16 +2157,16 @@ async def test_platform_multiple_conditions(hass: HomeAssistant) -> None:
class MockCondition1(MockCondition):
"""Mock condition 1."""
async def async_get_checker(self) -> ConditionChecker:
"""Evaluate state based on configuration."""
return lambda **kwargs: True
def _async_check(self, variables: TemplateVarsType) -> bool:
"""Check the condition."""
return True
class MockCondition2(MockCondition):
"""Mock condition 2."""
async def async_get_checker(self) -> ConditionChecker:
"""Evaluate state based on configuration."""
return lambda **kwargs: False
def _async_check(self, variables: TemplateVarsType) -> bool:
"""Check the condition."""
return False
async def async_get_conditions(hass: HomeAssistant) -> dict[str, type[Condition]]:
return {
@@ -2297,8 +2304,9 @@ async def test_get_condition_platform_registers_conditions(
) -> ConfigType:
return config
async def async_get_checker(self) -> ConditionChecker:
return lambda **kwargs: True
def _async_check(self, variables: TemplateVarsType) -> bool:
"""Check the condition."""
return True
async def async_get_conditions(
hass: HomeAssistant,
@@ -3103,7 +3111,7 @@ async def _setup_numerical_condition(
entity_ids: str | list[str],
domain_specs: Mapping[str, DomainSpec] | None = None,
valid_unit: str | None | UndefinedType = UNDEFINED,
) -> condition.ConditionCheckerType:
) -> condition.ConditionChecker:
"""Set up a numerical condition via a mock platform and return the test."""
condition_cls = make_entity_numerical_condition(
domain_specs or _DEFAULT_DOMAIN_SPECS, valid_unit
@@ -3432,7 +3440,7 @@ async def _setup_numerical_condition_with_unit(
domain_specs: Mapping[str, DomainSpec] | None = None,
base_unit: str = UnitOfTemperature.CELSIUS,
unit_converter: type = TemperatureConverter,
) -> condition.ConditionCheckerType:
) -> condition.ConditionChecker:
"""Set up a numerical condition with unit conversion via a mock platform."""
condition_cls = make_entity_numerical_condition_with_unit(
domain_specs or _DEFAULT_DOMAIN_SPECS, base_unit, unit_converter
@@ -3953,7 +3961,7 @@ async def _setup_state_condition(
condition_options: dict[str, Any] | None = None,
domain_specs: Mapping[str, DomainSpec] | None = None,
support_duration: bool = False,
) -> condition.ConditionCheckerType:
) -> condition.ConditionChecker:
"""Set up a state condition via a mock platform and return the checker."""
condition_cls = make_entity_state_condition(
domain_specs or _DEFAULT_DOMAIN_SPECS,
@@ -4341,3 +4349,554 @@ async def test_state_condition_duration_unavailable_unknown(
await hass.async_block_till_done()
freezer.tick(timedelta(seconds=11))
assert test_all(hass) is False
_ATTR_DOMAIN_SPECS: Mapping[str, DomainSpec] = {
"test": DomainSpec(value_source="test_attr")
}
async def _setup_attr_state_condition(
hass: HomeAssistant,
entity_ids: str | list[str],
states: str | bool | set[str | bool],
condition_options: dict[str, Any] | None = None,
) -> condition.ConditionChecker:
"""Set up an attribute-based state condition and return the checker."""
condition_cls = make_entity_state_condition(
_ATTR_DOMAIN_SPECS,
states,
support_duration=True,
)
async def async_get_conditions(
hass: HomeAssistant,
) -> dict[str, type[Condition]]:
return {"_": condition_cls}
mock_integration(hass, MockModule("test"))
mock_platform(
hass, "test.condition", Mock(async_get_conditions=async_get_conditions)
)
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
config: dict[str, Any] = {
CONF_CONDITION: "test",
CONF_TARGET: {CONF_ENTITY_ID: entity_ids},
CONF_OPTIONS: condition_options or {},
}
config = await async_validate_condition_config(hass, config)
test = await condition.async_from_config(hass, config)
assert test is not None
return test
async def test_state_condition_attr_duration_not_met(
hass: HomeAssistant, freezer: FrozenDateTimeFactory
) -> None:
"""Test attribute-based condition with duration: not met yet."""
test = await _setup_attr_state_condition(
hass,
entity_ids="test.entity_1",
states={True},
condition_options={CONF_FOR: {"seconds": 10}},
)
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
await hass.async_block_till_done()
# Just set — duration not met
assert test(hass) is False
freezer.tick(timedelta(seconds=5))
assert test(hass) is False
async def test_state_condition_attr_duration_met(
hass: HomeAssistant, freezer: FrozenDateTimeFactory
) -> None:
"""Test attribute-based condition with duration: met after waiting."""
test = await _setup_attr_state_condition(
hass,
entity_ids="test.entity_1",
states={True},
condition_options={CONF_FOR: {"seconds": 10}},
)
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
await hass.async_block_till_done()
freezer.tick(timedelta(seconds=11))
assert test(hass) is True
async def test_state_condition_attr_duration_reset_on_attr_change(
hass: HomeAssistant, freezer: FrozenDateTimeFactory
) -> None:
"""Test attribute-based condition: timer resets when attribute changes.
This is the key difference from state-based duration: the tracked value
is in an attribute, so state.last_changed does not capture it. The
_valid_since tracking in async_setup handles this correctly.
"""
test = await _setup_attr_state_condition(
hass,
entity_ids="test.entity_1",
states={True},
condition_options={CONF_FOR: {"seconds": 10}},
)
# Set attribute to True
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
await hass.async_block_till_done()
# After 8s, change attribute to False (state stays the same)
freezer.tick(timedelta(seconds=8))
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": False})
await hass.async_block_till_done()
# Set attribute back to True
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
await hass.async_block_till_done()
# 5s after re-set — not enough (timer was reset)
freezer.tick(timedelta(seconds=5))
assert test(hass) is False
# 6 more seconds (11 from re-set) — now met
freezer.tick(timedelta(seconds=6))
assert test(hass) is True
@pytest.mark.parametrize(
("behavior", "one_match_expected"),
[(BEHAVIOR_ANY, True), (BEHAVIOR_ALL, False)],
)
async def test_state_condition_attr_duration_behavior(
hass: HomeAssistant,
freezer: FrozenDateTimeFactory,
behavior: str,
one_match_expected: bool,
) -> None:
"""Test attribute-based condition with duration and behavior any/all."""
test = await _setup_attr_state_condition(
hass,
entity_ids=["test.entity_1", "test.entity_2"],
states={True},
condition_options={ATTR_BEHAVIOR: behavior, CONF_FOR: {"seconds": 10}},
)
hass.states.async_set("test.entity_1", STATE_ON, {"test_attr": True})
hass.states.async_set("test.entity_2", STATE_ON, {"test_attr": True})
await hass.async_block_till_done()
# Both matching but duration not met
assert test(hass) is False
# Advance past duration — both matching long enough
freezer.tick(timedelta(seconds=11))
assert test(hass) is True
# Change entity_2 attribute — only one matching for duration
hass.states.async_set("test.entity_2", STATE_ON, {"test_attr": False})
await hass.async_block_till_done()
assert test(hass) is one_match_expected
@dataclass
class _AttrInitStep:
"""A state update step before the condition is created."""
state: str
attrs: dict[str, Any] = field(default_factory=dict)
delay_before: int = 0
@pytest.mark.parametrize(
("steps", "duration", "initially_met"),
[
# Attribute set to valid 10s ago, no further changes → met (10 >= 5)
(
[_AttrInitStep(STATE_ON, {"test_attr": True})],
10,
True,
),
# Attribute set to valid 3s ago → not met (3 < 5)
(
[_AttrInitStep(STATE_ON, {"test_attr": True})],
3,
False,
),
# Attribute set to valid, then main state changes 2s later
# (attribute stays valid). last_updated is bumped by the state change,
# so the effective duration is only 2s from the second update → not met
(
[
_AttrInitStep(STATE_ON, {"test_attr": True}),
_AttrInitStep(STATE_OFF, {"test_attr": True}, delay_before=8),
],
2,
False,
),
# Same as above but enough time after the state change → met
(
[
_AttrInitStep(STATE_ON, {"test_attr": True}),
_AttrInitStep(STATE_OFF, {"test_attr": True}, delay_before=2),
],
8,
True,
),
# Attribute was invalid, then set to valid 4s ago → not met (4 < 5)
(
[
_AttrInitStep(STATE_ON, {"test_attr": False}),
_AttrInitStep(STATE_ON, {"test_attr": True}, delay_before=6),
],
4,
False,
),
# Attribute was invalid, then set to valid 6s ago → met (6 >= 5)
(
[
_AttrInitStep(STATE_ON, {"test_attr": False}),
_AttrInitStep(STATE_ON, {"test_attr": True}, delay_before=4),
],
6,
True,
),
# Attribute valid → invalid → valid 3s ago → not met (3 < 5)
(
[
_AttrInitStep(STATE_ON, {"test_attr": True}),
_AttrInitStep(STATE_ON, {"test_attr": False}, delay_before=5),
_AttrInitStep(STATE_ON, {"test_attr": True}, delay_before=2),
],
3,
False,
),
],
ids=[
"valid_long_enough",
"valid_too_short",
"state_change_bumps_last_updated_not_met",
"state_change_bumps_last_updated_met",
"invalid_then_valid_not_met",
"invalid_then_valid_met",
"valid_invalid_valid_not_met",
],
)
async def test_state_condition_attr_duration_initial_state(
hass: HomeAssistant,
freezer: FrozenDateTimeFactory,
steps: list[_AttrInitStep],
duration: int,
initially_met: bool,
) -> None:
"""Test attribute-based condition initialization from existing state.
The condition uses last_updated (not last_changed) to determine how long
an attribute-based condition has been true. This is conservative: when
the main state changes but the tracked attribute stays the same,
last_updated is bumped and the effective duration resets.
"""
for step in steps:
freezer.tick(timedelta(seconds=step.delay_before))
hass.states.async_set("test.entity_1", step.state, step.attrs)
await hass.async_block_till_done()
freezer.tick(timedelta(seconds=duration))
test = await _setup_attr_state_condition(
hass,
entity_ids="test.entity_1",
states={True},
condition_options={CONF_FOR: {"seconds": 5}},
)
assert test(hass) is initially_met
async def _setup_attr_state_condition_with_target(
hass: HomeAssistant,
target: dict[str, Any],
states: str | bool | set[str | bool],
condition_options: dict[str, Any] | None = None,
) -> condition.ConditionChecker:
"""Set up an attribute-based state condition with a custom target."""
condition_cls = make_entity_state_condition(
_ATTR_DOMAIN_SPECS,
states,
support_duration=True,
)
async def async_get_conditions(
hass: HomeAssistant,
) -> dict[str, type[Condition]]:
return {"_": condition_cls}
mock_integration(hass, MockModule("test"))
mock_platform(
hass, "test.condition", Mock(async_get_conditions=async_get_conditions)
)
config: dict[str, Any] = {
CONF_CONDITION: "test",
CONF_TARGET: target,
CONF_OPTIONS: condition_options or {},
}
config = await async_validate_condition_config(hass, config)
test = await condition.async_from_config(hass, config)
assert test is not None
return test
async def test_state_condition_attr_duration_entity_added_to_target(
hass: HomeAssistant, freezer: FrozenDateTimeFactory
) -> None:
"""Test that _valid_since is primed when an entity is added to the tracked set.
When targeting by label, adding a label to an entity should make it
tracked, and if it's already in a valid state, its duration should be
primed from the state timestamps.
"""
label_reg = lr.async_get(hass)
label = label_reg.async_create("Test Duration")
entity_reg = er.async_get(hass)
entry = entity_reg.async_get_or_create(
domain="test", platform="test", unique_id="duration_add"
)
# Entity starts valid but without the label
hass.states.async_set(entry.entity_id, STATE_ON, {"test_attr": True})
await hass.async_block_till_done()
# Create condition targeting the label
test = await _setup_attr_state_condition_with_target(
hass,
target={ATTR_LABEL_ID: label.label_id},
states={True},
condition_options={CONF_FOR: {"seconds": 5}},
)
# No entities have the label yet — condition has no entities to check,
# behavior "any" with no matching entities returns False
assert test(hass) is False
# Add the label to the entity — entity is already in valid state
freezer.tick(timedelta(seconds=1))
entity_reg.async_update_entity(entry.entity_id, labels={label.label_id})
await hass.async_block_till_done()
# Just added — duration not met yet
assert test(hass) is False
# Wait past the duration from when entity was last_updated
freezer.tick(timedelta(seconds=5))
assert test(hass) is True
async def test_state_condition_attr_duration_entity_removed_from_target(
hass: HomeAssistant, freezer: FrozenDateTimeFactory
) -> None:
"""Test that _valid_since is evicted when an entity is removed from the tracked set."""
label_reg = lr.async_get(hass)
label = label_reg.async_create("Test Duration Remove")
entity_reg = er.async_get(hass)
entry1 = entity_reg.async_get_or_create(
domain="test", platform="test", unique_id="duration_remove_1"
)
entry2 = entity_reg.async_get_or_create(
domain="test", platform="test", unique_id="duration_remove_2"
)
# Both entities start with the label
entity_reg.async_update_entity(entry1.entity_id, labels={label.label_id})
entity_reg.async_update_entity(entry2.entity_id, labels={label.label_id})
# Both entities in valid state
hass.states.async_set(entry1.entity_id, STATE_ON, {"test_attr": True})
hass.states.async_set(entry2.entity_id, STATE_ON, {"test_attr": True})
await hass.async_block_till_done()
test = await _setup_attr_state_condition_with_target(
hass,
target={ATTR_LABEL_ID: label.label_id},
states={True},
condition_options={
ATTR_BEHAVIOR: BEHAVIOR_ALL,
CONF_FOR: {"seconds": 5},
},
)
# Wait past duration — both valid
freezer.tick(timedelta(seconds=6))
assert test(hass) is True
# Remove label from entry2
entity_reg.async_update_entity(entry2.entity_id, labels=set())
await hass.async_block_till_done()
# Condition should still be True — only entry1 is tracked now, and it's valid
assert test(hass) is True
# Now remove label from entry1 too
entity_reg.async_update_entity(entry1.entity_id, labels=set())
await hass.async_block_till_done()
# No entities tracked — "all" with empty set is vacuously True
assert test(hass) is True
# Change entry1 to invalid state and re-add its label
hass.states.async_set(entry1.entity_id, STATE_ON, {"test_attr": False})
await hass.async_block_till_done()
entity_reg.async_update_entity(entry1.entity_id, labels={label.label_id})
await hass.async_block_till_done()
# entry1 is now tracked again but invalid — "all" fails
freezer.tick(timedelta(seconds=10))
assert test(hass) is False
async def test_state_condition_attr_duration_entity_added_then_state_changes(
hass: HomeAssistant, freezer: FrozenDateTimeFactory
) -> None:
"""Test that a newly added entity's state changes are properly tracked."""
label_reg = lr.async_get(hass)
label = label_reg.async_create("Test Duration Track")
entity_reg = er.async_get(hass)
entry = entity_reg.async_get_or_create(
domain="test", platform="test", unique_id="duration_track"
)
# Entity starts in invalid state
hass.states.async_set(entry.entity_id, STATE_ON, {"test_attr": False})
await hass.async_block_till_done()
# Create condition targeting the label
test = await _setup_attr_state_condition_with_target(
hass,
target={ATTR_LABEL_ID: label.label_id},
states={True},
condition_options={CONF_FOR: {"seconds": 5}},
)
# Add the label — entity is invalid, so no priming
entity_reg.async_update_entity(entry.entity_id, labels={label.label_id})
await hass.async_block_till_done()
assert test(hass) is False
# Now change to valid state
freezer.tick(timedelta(seconds=1))
hass.states.async_set(entry.entity_id, STATE_ON, {"test_attr": True})
await hass.async_block_till_done()
# Just became valid — not long enough
freezer.tick(timedelta(seconds=3))
assert test(hass) is False
# Now past the duration
freezer.tick(timedelta(seconds=3))
assert test(hass) is True
@pytest.mark.parametrize(
"compound_type",
["and", "or", "not"],
)
async def test_compound_condition_forwards_async_unload(
hass: HomeAssistant, compound_type: str
) -> None:
"""Test that and/or/not compound conditions forward async_unload to children."""
config = {
"condition": compound_type,
"conditions": [
{
"condition": "state",
"entity_id": "test.entity_1",
"state": STATE_ON,
},
{
"condition": "state",
"entity_id": "test.entity_2",
"state": STATE_ON,
},
],
}
config = cv.CONDITION_SCHEMA(config)
config = await condition.async_validate_condition_config(hass, config)
test = await condition.async_from_config(hass, config)
# The compound checker should hold child checkers
assert hasattr(test, "_checks")
assert len(test._checks) == 2
# Patch async_unload on children to verify forwarding
child_unloads = [Mock() for _ in test._checks]
for child, mock_unload in zip(test._checks, child_unloads, strict=True):
child.async_unload = mock_unload
test.async_unload()
for mock_unload in child_unloads:
mock_unload.assert_called_once()
@pytest.mark.parametrize(
("outer_type", "inner_type"),
[
(outer, inner)
for outer in ("and", "or", "not")
for inner in ("and", "or", "not")
],
)
async def test_nested_compound_condition_forwards_async_unload(
hass: HomeAssistant, outer_type: str, inner_type: str
) -> None:
"""Test that nested compound conditions forward async_unload recursively."""
config = {
"condition": outer_type,
"conditions": [
{
"condition": inner_type,
"conditions": [
{
"condition": "state",
"entity_id": "test.entity_1",
"state": STATE_ON,
},
],
},
{
"condition": "state",
"entity_id": "test.entity_2",
"state": STATE_ON,
},
],
}
config = cv.CONDITION_SCHEMA(config)
config = await condition.async_validate_condition_config(hass, config)
test = await condition.async_from_config(hass, config)
# Outer compound with 2 children: an inner compound and a leaf
assert len(test._checks) == 2
inner_checker = test._checks[0]
assert hasattr(inner_checker, "_checks")
assert len(inner_checker._checks) == 1
# Patch the innermost leaf's async_unload
innermost_unload = Mock()
inner_checker._checks[0].async_unload = innermost_unload
leaf_unload = Mock()
test._checks[1].async_unload = leaf_unload
test.async_unload()
innermost_unload.assert_called_once()
leaf_unload.assert_called_once()
+118
View File
@@ -762,3 +762,121 @@ async def test_async_track_target_selector_state_change_event_filter(
)
unsub()
async def test_async_track_target_selector_state_change_event_on_entities_update(
hass: HomeAssistant,
) -> None:
"""Test on_entities_update callback reports added and removed entities."""
entity_updates: list[tuple[set[str], set[str]]] = []
@callback
def state_change_callback(event: target.TargetStateChangedData) -> None:
"""Handle state change events."""
@callback
def on_entities_update(added: set[str], removed: set[str]) -> None:
"""Track entity set changes."""
entity_updates.append((added, removed))
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
entity_reg = er.async_get(hass)
label_reg = lr.async_get(hass)
label = label_reg.async_create("Track Test")
entity_a = entity_reg.async_get_or_create(
domain="light", platform="test", unique_id="track_a"
)
entity_b = entity_reg.async_get_or_create(
domain="light", platform="test", unique_id="track_b"
)
# entity_a starts with the label
entity_reg.async_update_entity(entity_a.entity_id, labels={label.label_id})
hass.states.async_set(entity_a.entity_id, STATE_ON)
hass.states.async_set(entity_b.entity_id, STATE_ON)
await hass.async_block_till_done()
unsub = target.async_track_target_selector_state_change_event(
hass,
{ATTR_LABEL_ID: label.label_id},
state_change_callback,
on_entities_update=on_entities_update,
)
# Initial setup fires on_entities_update with all entities as "added"
assert len(entity_updates) == 1
assert entity_updates[-1] == ({entity_a.entity_id}, set())
entity_updates.clear()
# Add label to entity_b → added
entity_reg.async_update_entity(entity_b.entity_id, labels={label.label_id})
await hass.async_block_till_done()
assert len(entity_updates) == 1
assert entity_updates[-1] == ({entity_b.entity_id}, set())
entity_updates.clear()
# Remove label from entity_a → removed
entity_reg.async_update_entity(entity_a.entity_id, labels=set())
await hass.async_block_till_done()
assert len(entity_updates) == 1
assert entity_updates[-1] == (set(), {entity_a.entity_id})
entity_updates.clear()
# Remove label from entity_b → removed
entity_reg.async_update_entity(entity_b.entity_id, labels=set())
await hass.async_block_till_done()
assert len(entity_updates) == 1
assert entity_updates[-1] == (set(), {entity_b.entity_id})
entity_updates.clear()
# Re-add both labels at once — entity_a first, then entity_b
entity_reg.async_update_entity(entity_a.entity_id, labels={label.label_id})
await hass.async_block_till_done()
entity_reg.async_update_entity(entity_b.entity_id, labels={label.label_id})
await hass.async_block_till_done()
assert len(entity_updates) == 2
assert entity_updates[0] == ({entity_a.entity_id}, set())
assert entity_updates[1] == ({entity_b.entity_id}, set())
entity_updates.clear()
# After unsubscribing, no more callbacks
unsub()
entity_reg.async_update_entity(entity_a.entity_id, labels=set())
await hass.async_block_till_done()
assert len(entity_updates) == 0
async def test_async_track_target_selector_no_on_entities_update(
hass: HomeAssistant,
) -> None:
"""Test that on_entities_update is optional and defaults to no callback."""
events: list[target.TargetStateChangedData] = []
@callback
def state_change_callback(event: target.TargetStateChangedData) -> None:
events.append(event)
entity_id = "light.test_no_callback"
hass.states.async_set(entity_id, STATE_ON)
await hass.async_block_till_done()
# No on_entities_update — should work without errors
unsub = target.async_track_target_selector_state_change_event(
hass,
{ATTR_ENTITY_ID: entity_id},
state_change_callback,
)
hass.states.async_set(entity_id, STATE_OFF)
await hass.async_block_till_done()
assert len(events) == 1
unsub()