Compare commits

...

2 Commits

Author SHA1 Message Date
Erik 154704f998 Fix docstring 2026-05-07 07:38:25 +02:00
Erik 4f5c44b973 Filter excluded states in entity trigger base class 2026-05-06 23:05:07 +02:00
10 changed files with 57 additions and 113 deletions
+3 -18
View File
@@ -1,11 +1,6 @@
"""Provides triggers for counters."""
from homeassistant.const import (
CONF_MAXIMUM,
CONF_MINIMUM,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.const import CONF_MAXIMUM, CONF_MINIMUM
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
@@ -41,9 +36,7 @@ class CounterDecrementedTrigger(CounterBaseIntegerTrigger):
"""Trigger for when a counter is decremented."""
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
"""Check that the counter value decreased."""
return int(from_state.state) > int(to_state.state)
@@ -51,9 +44,7 @@ class CounterIncrementedTrigger(CounterBaseIntegerTrigger):
"""Trigger for when a counter is incremented."""
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
"""Check that the counter value increased."""
return int(from_state.state) < int(to_state.state)
@@ -62,12 +53,6 @@ class CounterValueBaseTrigger(EntityTriggerBase):
_domain_specs = {DOMAIN: DomainSpec()}
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
return from_state.state != to_state.state
class CounterMaxReachedTrigger(CounterValueBaseTrigger):
"""Trigger for when a counter reaches its maximum value."""
+2 -4
View File
@@ -2,7 +2,7 @@
from collections.abc import Mapping
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.trigger import EntityTriggerBase, Trigger
@@ -28,9 +28,7 @@ class CoverTriggerBase(EntityTriggerBase):
return self._get_value(state) == domain_spec.target_value
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the transition is valid for a cover state change."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
"""Check that the relevant cover value changed."""
if (from_value := self._get_value(from_state)) is None:
return False
return from_value != self._get_value(to_state)
+2 -4
View File
@@ -17,10 +17,8 @@ class DoorbellRangTrigger(StatelessEntityTriggerBase):
_domain_specs = {EVENT_DOMAIN: DomainSpec(device_class=EventDeviceClass.DOORBELL)}
def is_valid_state(self, state: State) -> bool:
"""Check if the entity is available and the event type is ring."""
return super().is_valid_state(state) and (
state.attributes.get(ATTR_EVENT_TYPE) == DoorbellEventType.RING
)
"""Check if the event type is ring."""
return state.attributes.get(ATTR_EVENT_TYPE) == DoorbellEventType.RING
TRIGGERS: dict[str, type[Trigger]] = {
+1 -3
View File
@@ -41,9 +41,7 @@ class EventReceivedTrigger(StatelessEntityTriggerBase):
def is_valid_state(self, state: State) -> bool:
"""Check if the event type matches one of the configured types."""
return super().is_valid_state(state) and (
state.attributes.get(ATTR_EVENT_TYPE) in self._event_types
)
return state.attributes.get(ATTR_EVENT_TYPE) in self._event_types
TRIGGERS: dict[str, type[Trigger]] = {
@@ -1,6 +1,5 @@
"""Provides triggers for media players."""
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
@@ -50,10 +49,7 @@ class _MediaPlayerMutedStateTriggerBase(EntityTriggerBase):
)
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
"""Check that the muted-state changed."""
if not self._has_volume_attributes(to_state):
return False
+2 -5
View File
@@ -1,6 +1,6 @@
"""Provides triggers for schedules."""
from homeassistant.const import STATE_OFF, STATE_ON, STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
@@ -20,10 +20,7 @@ class ScheduleBackToBackTrigger(EntityTransitionTriggerBase):
_to_states = {STATE_ON}
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state matches the expected ones."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
"""Check that the origin matches and the next event changed."""
from_next_event = from_state.attributes.get(ATTR_NEXT_EVENT)
to_next_event = to_state.attributes.get(ATTR_NEXT_EVENT)
+1 -12
View File
@@ -1,8 +1,7 @@
"""Provides triggers for selects."""
from homeassistant.components.input_select import DOMAIN as INPUT_SELECT_DOMAIN
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.core import HomeAssistant
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
@@ -19,16 +18,6 @@ class SelectionChangedTrigger(EntityTriggerBase):
_domain_specs = {DOMAIN: DomainSpec(), INPUT_SELECT_DOMAIN: DomainSpec()}
_schema = ENTITY_STATE_TRIGGER_SCHEMA
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
return from_state.state != to_state.state
def is_valid_state(self, state: State) -> bool:
"""Check if the new state is not invalid."""
return state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
TRIGGERS: dict[str, type[Trigger]] = {
"selection_changed": SelectionChangedTrigger,
+1 -12
View File
@@ -1,8 +1,7 @@
"""Provides triggers for text and input_text entities."""
from homeassistant.components.input_text import DOMAIN as INPUT_TEXT_DOMAIN
from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN
from homeassistant.core import HomeAssistant, State
from homeassistant.core import HomeAssistant
from homeassistant.helpers.automation import DomainSpec
from homeassistant.helpers.trigger import (
ENTITY_STATE_TRIGGER_SCHEMA,
@@ -19,16 +18,6 @@ class TextChangedTrigger(EntityTriggerBase):
_domain_specs = {DOMAIN: DomainSpec(), INPUT_TEXT_DOMAIN: DomainSpec()}
_schema = ENTITY_STATE_TRIGGER_SCHEMA
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
return from_state.state != to_state.state
def is_valid_state(self, state: State) -> bool:
"""Check if the new state is not invalid."""
return state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN)
TRIGGERS: dict[str, type[Trigger]] = {
"changed": TextChangedTrigger,
+44 -42
View File
@@ -353,9 +353,14 @@ class EntityTriggerBase(Trigger):
"""Trigger for entity state changes."""
_domain_specs: Mapping[str, DomainSpec]
# States filtered from the to_state pre-filter (and `_should_include`).
_excluded_states: Final[frozenset[str]] = frozenset(
{STATE_UNAVAILABLE, STATE_UNKNOWN}
)
# States filtered from the from_state pre-filter. Defaults to
# `_excluded_states`. Subclasses can override to relax the origin
# check.
_excluded_from_states: ClassVar[frozenset[str]] = _excluded_states
_schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA_FIRST_LAST
# When True, indirect target expansion (via device/area/floor) skips
# entities with an entity_category.
@@ -389,13 +394,28 @@ class EntityTriggerBase(Trigger):
return state.state
return state.attributes.get(domain_spec.value_source)
@abc.abstractmethod
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
"""Check if the transition should fire the trigger.
Called only after `from_state.state` has been filtered against
`_excluded_from_states` and `to_state.state` against
`_excluded_states`, so subclasses don't need to repeat those
checks. Default: any state change. Override to add semantics
(specific from/to states, value changed across a threshold,
etc.).
"""
return from_state.state != to_state.state
@abc.abstractmethod
def is_valid_state(self, state: State) -> bool:
"""Check if the new state matches the expected state(s)."""
"""Check if the state is a target state for the trigger.
Called only after `state.state` has been filtered against
`_excluded_states`, so subclasses don't need to repeat that
check. Default: any non-excluded state is a target. Override
to restrict (specific to_states, value within a threshold,
etc.).
"""
return True
def _should_include(self, state: State) -> bool:
"""Check if an entity should participate in all/count checks.
@@ -473,19 +493,26 @@ class EntityTriggerBase(Trigger):
)
return matches >= 1
# Behavior any: check the individual entity's state
if not to_state:
if not to_state or to_state.state in self._excluded_states:
return False
return self.is_valid_state(to_state)
if not from_state or not to_state:
return
# The trigger should never fire if the new state is not valid
if not self.is_valid_state(to_state):
# The trigger should never fire if the new state is excluded
# or not a target state.
if to_state.state in self._excluded_states or not self.is_valid_state(
to_state
):
return
# The trigger should never fire if the transition is not valid
if not self.is_valid_transition(from_state, to_state):
# The trigger should never fire if the origin state is excluded
# or the transition is not valid.
if (
from_state.state in self._excluded_from_states
or not self.is_valid_transition(from_state, to_state)
):
return
if behavior == BEHAVIOR_LAST:
@@ -570,10 +597,7 @@ class EntityTargetStateTriggerBase(EntityTriggerBase):
_to_states: set[str]
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
"""Check the value changed and the origin was not already a target state."""
from_value = self._get_tracked_value(from_state)
return (
from_value != self._get_tracked_value(to_state)
@@ -593,9 +617,6 @@ class EntityTransitionTriggerBase(EntityTriggerBase):
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state matches the expected ones."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
from_value = self._get_tracked_value(from_state)
return (
from_value != self._get_tracked_value(to_state)
@@ -620,10 +641,8 @@ class EntityOriginStateTriggerBase(EntityTriggerBase):
)
def is_valid_state(self, state: State) -> bool:
"""Check if the new state is valid."""
return state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) and bool(
self._get_tracked_value(state) != self._from_state
)
"""Check that the new state is different from the origin state."""
return bool(self._get_tracked_value(state) != self._from_state)
class StatelessEntityTriggerBase(EntityTriggerBase):
@@ -631,23 +650,12 @@ class StatelessEntityTriggerBase(EntityTriggerBase):
Used for stateless entities (buttons, scenes, doorbells, events)
whose `state.state` is just a timestamp of the last activation.
`STATE_UNKNOWN` is a legitimate prior state — the first activation
after startup must still fire the trigger.
"""
_schema: vol.Schema = ENTITY_STATE_TRIGGER_SCHEMA
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is available and the state has changed.
STATE_UNKNOWN is allowed as the origin state so the first
activation fires.
"""
if from_state.state == STATE_UNAVAILABLE:
return False
return from_state.state != to_state.state
def is_valid_state(self, state: State) -> bool:
"""Check that the entity has been activated at least once."""
return state.state not in self._excluded_states
_excluded_from_states: ClassVar[frozenset[str]] = frozenset({STATE_UNAVAILABLE})
NUMERICAL_ATTRIBUTE_CHANGED_TRIGGER_SCHEMA = ENTITY_STATE_TRIGGER_SCHEMA.extend(
@@ -826,10 +834,7 @@ class EntityNumericalStateChangedTriggerBase(EntityNumericalStateTriggerBase):
_schema = NUMERICAL_ATTRIBUTE_CHANGED_TRIGGER_SCHEMA
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
"""Check if the tracked numeric value has changed."""
return self._get_tracked_value(from_state) != self._get_tracked_value(to_state)
@@ -888,10 +893,7 @@ class EntityNumericalStateCrossedThresholdTriggerBase(EntityNumericalStateTrigge
_schema = NUMERICAL_ATTRIBUTE_CROSSED_THRESHOLD_SCHEMA
def is_valid_transition(self, from_state: State, to_state: State) -> bool:
"""Check if the origin state is valid and the state has changed."""
if from_state.state in (STATE_UNAVAILABLE, STATE_UNKNOWN):
return False
"""Check that the tracked value crossed into the threshold range."""
return not self.is_valid_state(from_state)
-8
View File
@@ -2969,10 +2969,6 @@ async def test_make_entity_target_state_trigger(
# Value did not change — not a valid transition
assert not trig.is_valid_transition(from_state, from_state)
# From unavailable — not valid
unavailable = State("light.bed", STATE_UNAVAILABLE, {})
assert not trig.is_valid_transition(unavailable, to_state)
# Value not in to_states — not valid
assert not trig.is_valid_state(wrong_value_state)
@@ -3043,10 +3039,6 @@ async def test_make_entity_transition_trigger(
# No change in tracked value — not a valid transition
assert not trig.is_valid_transition(from_state, from_state)
# From unavailable — not valid
unavailable = State("climate.living", STATE_UNAVAILABLE, {})
assert not trig.is_valid_transition(unavailable, to_state)
@pytest.mark.parametrize(
("domain_specs", "origin", "from_state", "to_state", "wrong_from"),