Compare commits

...

2 Commits

Author SHA1 Message Date
Erik
45f1fc32e2 Address comments 2026-01-08 09:36:34 +01:00
Erik
1e95598c96 Add helper for creating entity condition tests 2026-01-07 14:58:50 +01:00
2 changed files with 146 additions and 89 deletions

View File

@@ -172,6 +172,88 @@ class StateDescription(TypedDict):
count: int
class ConditionStateDescription(TypedDict):
"""Test state and expected service call count."""
included: _StateDescription
excluded: _StateDescription
count: int
valid: bool
def parametrize_condition_states(
*,
condition: str,
condition_options: dict[str, Any] | None = None,
target_states: list[str | None | tuple[str | None, dict]],
other_states: list[str | None | tuple[str | None, dict]],
additional_attributes: dict | None = None,
) -> list[tuple[str, dict[str, Any], list[ConditionStateDescription]]]:
"""Parametrize states and expected service call counts.
The target_states and other_states iterables are either iterables of
states or iterables of (state, attributes) tuples.
Returns a list of tuples with (trigger, list of states),
where states is a list of ConditionStateDescription dicts.
"""
additional_attributes = additional_attributes or {}
condition_options = condition_options or {}
def state_with_attributes(
state: str | None | tuple[str | None, dict], count: int, valid: bool
) -> ConditionStateDescription:
"""Return (state, attributes) dict."""
if isinstance(state, str) or state is None:
return {
"included": {
"state": state,
"attributes": additional_attributes,
},
"excluded": {
"state": state,
"attributes": {},
},
"count": count,
"valid": valid,
}
return {
"included": {
"state": state[0],
"attributes": state[1] | additional_attributes,
},
"excluded": {
"state": state[0],
"attributes": state[1],
},
"count": count,
"valid": valid,
}
return [
(
condition,
condition_options,
list(
itertools.chain(
(state_with_attributes(None, 0, False),),
(state_with_attributes(STATE_UNAVAILABLE, 0, False),),
(state_with_attributes(STATE_UNKNOWN, 0, False),),
(
state_with_attributes(other_state, 0, True)
for other_state in other_states
),
(
state_with_attributes(target_state, 1, True)
for target_state in target_states
),
)
),
),
]
def parametrize_trigger_states(
*,
trigger: str,
@@ -202,7 +284,7 @@ def parametrize_trigger_states(
def state_with_attributes(
state: str | None | tuple[str | None, dict], count: int
) -> dict:
) -> StateDescription:
"""Return (state, attributes) dict."""
if isinstance(state, str) or state is None:
return {

View File

@@ -1,6 +1,7 @@
"""Test light conditions."""
from collections.abc import Generator
from typing import Any
from unittest.mock import patch
import pytest
@@ -13,24 +14,18 @@ from homeassistant.const import (
CONF_TARGET,
STATE_OFF,
STATE_ON,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.setup import async_setup_component
from tests.components import (
ConditionStateDescription,
parametrize_condition_states,
parametrize_target_entities,
set_or_remove_state,
target_entities,
)
INVALID_STATES = [
{"state": STATE_UNAVAILABLE, "attributes": {}},
{"state": STATE_UNKNOWN, "attributes": {}},
{"state": None, "attributes": {}},
]
@pytest.fixture(autouse=True, name="stub_blueprint_populate")
def stub_blueprint_populate_autouse(stub_blueprint_populate: None) -> None:
@@ -76,15 +71,15 @@ async def setup_automation_with_light_condition(
)
async def has_call_after_trigger(
async def calls_after_trigger(
hass: HomeAssistant, service_calls: list[ServiceCall]
) -> bool:
"""Check if there are service calls after the trigger event."""
) -> int:
"""Return number of service calls after the trigger event."""
hass.bus.async_fire("test_event")
await hass.async_block_till_done()
has_calls = len(service_calls) == 1
num_calls = len(service_calls)
service_calls.clear()
return has_calls
return num_calls
@pytest.fixture(name="enable_experimental_triggers_conditions")
@@ -125,17 +120,17 @@ async def test_light_conditions_gated_by_labs_flag(
parametrize_target_entities("light"),
)
@pytest.mark.parametrize(
("condition", "target_state", "other_state"),
("condition", "condition_options", "states"),
[
(
"light.is_on",
{"state": STATE_ON, "attributes": {}},
{"state": STATE_OFF, "attributes": {}},
*parametrize_condition_states(
condition="light.is_on",
target_states=[STATE_ON],
other_states=[STATE_OFF],
),
(
"light.is_off",
{"state": STATE_OFF, "attributes": {}},
{"state": STATE_ON, "attributes": {}},
*parametrize_condition_states(
condition="light.is_off",
target_states=[STATE_OFF],
other_states=[STATE_ON],
),
],
)
@@ -148,15 +143,15 @@ async def test_light_state_condition_behavior_any(
entity_id: str,
entities_in_target: int,
condition: str,
target_state: str,
other_state: str,
condition_options: dict[str, Any],
states: list[ConditionStateDescription],
) -> None:
"""Test the light state condition with the 'any' behavior."""
other_entity_ids = set(target_lights) - {entity_id}
# Set all lights, including the tested light, to the initial state
for eid in target_lights:
set_or_remove_state(hass, eid, other_state)
set_or_remove_state(hass, eid, states[0]["included"])
await hass.async_block_till_done()
await setup_automation_with_light_condition(
@@ -167,38 +162,23 @@ async def test_light_state_condition_behavior_any(
)
# Set state for switches to ensure that they don't impact the condition
for eid in target_switches:
set_or_remove_state(hass, eid, other_state)
await hass.async_block_till_done()
assert not await has_call_after_trigger(hass, service_calls)
for state in states:
for eid in target_switches:
set_or_remove_state(hass, eid, state["included"])
await hass.async_block_till_done()
assert await calls_after_trigger(hass, service_calls) == 0
for eid in target_switches:
set_or_remove_state(hass, eid, target_state)
await hass.async_block_till_done()
assert not await has_call_after_trigger(hass, service_calls)
for state in states:
included_state = state["included"]
set_or_remove_state(hass, entity_id, included_state)
await hass.async_block_till_done()
assert await calls_after_trigger(hass, service_calls) == state["count"]
# Set one light to the condition state -> condition pass
set_or_remove_state(hass, entity_id, target_state)
assert await has_call_after_trigger(hass, service_calls)
# Set all remaining lights to the condition state -> condition pass
for eid in other_entity_ids:
set_or_remove_state(hass, eid, target_state)
assert await has_call_after_trigger(hass, service_calls)
for invalid_state in INVALID_STATES:
# Set one light to the invalid state -> condition pass if there are
# other lights in the condition state
set_or_remove_state(hass, entity_id, invalid_state)
assert await has_call_after_trigger(hass, service_calls) == bool(
entities_in_target - 1
)
for invalid_state in INVALID_STATES:
# Set all lights to invalid state -> condition fail
for eid in other_entity_ids:
set_or_remove_state(hass, eid, invalid_state)
assert not await has_call_after_trigger(hass, service_calls)
# Check if changing other lights also passes the condition
for other_entity_id in other_entity_ids:
set_or_remove_state(hass, other_entity_id, included_state)
await hass.async_block_till_done()
assert await calls_after_trigger(hass, service_calls) == state["count"]
@pytest.mark.usefixtures("enable_experimental_triggers_conditions")
@@ -207,17 +187,17 @@ async def test_light_state_condition_behavior_any(
parametrize_target_entities("light"),
)
@pytest.mark.parametrize(
("condition", "target_state", "other_state"),
("condition", "condition_options", "states"),
[
(
"light.is_on",
{"state": STATE_ON, "attributes": {}},
{"state": STATE_OFF, "attributes": {}},
*parametrize_condition_states(
condition="light.is_on",
target_states=[STATE_ON],
other_states=[STATE_OFF],
),
(
"light.is_off",
{"state": STATE_OFF, "attributes": {}},
{"state": STATE_ON, "attributes": {}},
*parametrize_condition_states(
condition="light.is_off",
target_states=[STATE_OFF],
other_states=[STATE_ON],
),
],
)
@@ -229,8 +209,8 @@ async def test_light_state_condition_behavior_all(
entity_id: str,
entities_in_target: int,
condition: str,
target_state: str,
other_state: str,
condition_options: dict[str, Any],
states: list[ConditionStateDescription],
) -> None:
"""Test the light state condition with the 'all' behavior."""
# Set state for two switches to ensure that they don't impact the condition
@@ -241,7 +221,7 @@ async def test_light_state_condition_behavior_all(
# Set all lights, including the tested light, to the initial state
for eid in target_lights:
set_or_remove_state(hass, eid, other_state)
set_or_remove_state(hass, eid, states[0]["included"])
await hass.async_block_till_done()
await setup_automation_with_light_condition(
@@ -251,27 +231,22 @@ async def test_light_state_condition_behavior_all(
behavior="all",
)
# No lights on the condition state
assert not await has_call_after_trigger(hass, service_calls)
for state in states:
included_state = state["included"]
# Set one light to the condition state -> condition fail
set_or_remove_state(hass, entity_id, target_state)
assert await has_call_after_trigger(hass, service_calls) == (
entities_in_target == 1
)
set_or_remove_state(hass, entity_id, included_state)
await hass.async_block_till_done()
# The condition passes if all entities are either in a target state or invalid
assert await calls_after_trigger(hass, service_calls) == (
(not state["valid"]) or (state["count"] if entities_in_target == 1 else 0)
)
# Set all remaining lights to the condition state -> condition pass
for eid in other_entity_ids:
set_or_remove_state(hass, eid, target_state)
assert await has_call_after_trigger(hass, service_calls)
for other_entity_id in other_entity_ids:
set_or_remove_state(hass, other_entity_id, included_state)
await hass.async_block_till_done()
for invalid_state in INVALID_STATES:
# Set one light to the invalid state -> condition still pass
set_or_remove_state(hass, entity_id, invalid_state)
assert await has_call_after_trigger(hass, service_calls)
for invalid_state in INVALID_STATES:
# Set all lights to unavailable -> condition passes
for eid in other_entity_ids:
set_or_remove_state(hass, eid, invalid_state)
assert await has_call_after_trigger(hass, service_calls)
# The condition passes if all entities are either in a target state or invalid
assert (
await calls_after_trigger(hass, service_calls) == (not state["valid"])
or state["count"]
)