Use satellite entity area in the assist pipeline (#153017)

This commit is contained in:
Artur Pragacz
2025-09-26 21:34:45 +02:00
committed by GitHub
parent a6c3f4efc0
commit 953895cd81
2 changed files with 61 additions and 27 deletions

View File

@@ -1308,7 +1308,9 @@ class PipelineRun:
# instead of a full response.
all_targets_in_satellite_area = (
self._get_all_targets_in_satellite_area(
conversation_result.response, self._device_id
conversation_result.response,
self._satellite_id,
self._device_id,
)
)
@@ -1337,39 +1339,62 @@ class PipelineRun:
return (speech, all_targets_in_satellite_area)
def _get_all_targets_in_satellite_area(
self, intent_response: intent.IntentResponse, device_id: str | None
self,
intent_response: intent.IntentResponse,
satellite_id: str | None,
device_id: str | None,
) -> bool:
"""Return true if all targeted entities were in the same area as the device."""
if (
(intent_response.response_type != intent.IntentResponseType.ACTION_DONE)
or (not intent_response.matched_states)
or (not device_id)
):
return False
device_registry = dr.async_get(self.hass)
if (not (device := device_registry.async_get(device_id))) or (
not device.area_id
intent_response.response_type != intent.IntentResponseType.ACTION_DONE
or not intent_response.matched_states
):
return False
entity_registry = er.async_get(self.hass)
for state in intent_response.matched_states:
entity = entity_registry.async_get(state.entity_id)
if not entity:
device_registry = dr.async_get(self.hass)
area_id: str | None = None
if (
satellite_id is not None
and (target_entity_entry := entity_registry.async_get(satellite_id))
is not None
):
area_id = target_entity_entry.area_id
device_id = target_entity_entry.device_id
if area_id is None:
if device_id is None:
return False
if (entity_area_id := entity.area_id) is None:
if (entity.device_id is None) or (
(entity_device := device_registry.async_get(entity.device_id))
is None
):
device_entry = device_registry.async_get(device_id)
if device_entry is None:
return False
area_id = device_entry.area_id
if area_id is None:
return False
for state in intent_response.matched_states:
target_entity_entry = entity_registry.async_get(state.entity_id)
if target_entity_entry is None:
return False
target_area_id = target_entity_entry.area_id
if target_area_id is None:
if target_entity_entry.device_id is None:
return False
entity_area_id = entity_device.area_id
target_device_entry = device_registry.async_get(
target_entity_entry.device_id
)
if target_device_entry is None:
return False
if entity_area_id != device.area_id:
target_area_id = target_device_entry.area_id
if target_area_id != area_id:
return False
return True

View File

@@ -1797,6 +1797,7 @@ async def test_chat_log_tts_streaming(
assert process_events(events) == snapshot
@pytest.mark.parametrize(("use_satellite_entity"), [True, False])
async def test_acknowledge(
hass: HomeAssistant,
init_components,
@@ -1805,6 +1806,7 @@ async def test_acknowledge(
entity_registry: er.EntityRegistry,
area_registry: ar.AreaRegistry,
device_registry: dr.DeviceRegistry,
use_satellite_entity: bool,
) -> None:
"""Test that acknowledge sound is played when targets are in the same area."""
area_1 = area_registry.async_get_or_create("area_1")
@@ -1819,12 +1821,16 @@ async def test_acknowledge(
entry = MockConfigEntry()
entry.add_to_hass(hass)
satellite = device_registry.async_get_or_create(
satellite = entity_registry.async_get_or_create("assist_satellite", "test", "1234")
entity_registry.async_update_entity(satellite.entity_id, area_id=area_1.id)
satellite_device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections=set(),
identifiers={("demo", "id-1234")},
)
device_registry.async_update_device(satellite.id, area_id=area_1.id)
device_registry.async_update_device(satellite_device.id, area_id=area_1.id)
events: list[assist_pipeline.PipelineEvent] = []
turn_on = async_mock_service(hass, "light", "turn_on")
@@ -1837,7 +1843,8 @@ async def test_acknowledge(
pipeline_input = assist_pipeline.pipeline.PipelineInput(
intent_input=text,
session=mock_chat_session,
device_id=satellite.id,
satellite_id=satellite.entity_id if use_satellite_entity else None,
device_id=satellite_device.id if not use_satellite_entity else None,
run=assist_pipeline.pipeline.PipelineRun(
hass,
context=Context(),
@@ -1889,7 +1896,8 @@ async def test_acknowledge(
)
# 3. Remove satellite device area
device_registry.async_update_device(satellite.id, area_id=None)
entity_registry.async_update_entity(satellite.entity_id, area_id=None)
device_registry.async_update_device(satellite_device.id, area_id=None)
_reset()
await _run("turn on light 1")
@@ -1900,7 +1908,8 @@ async def test_acknowledge(
assert len(turn_on) == 1
# Restore
device_registry.async_update_device(satellite.id, area_id=area_1.id)
entity_registry.async_update_entity(satellite.entity_id, area_id=area_1.id)
device_registry.async_update_device(satellite_device.id, area_id=area_1.id)
# 4. Check device area instead of entity area
light_device = device_registry.async_get_or_create(