mirror of
https://github.com/home-assistant/core.git
synced 2026-04-08 15:52:28 +02:00
Use satellite entity area in the assist pipeline (#153017)
This commit is contained in:
@@ -1308,7 +1308,9 @@ class PipelineRun:
|
||||
# instead of a full response.
|
||||
all_targets_in_satellite_area = (
|
||||
self._get_all_targets_in_satellite_area(
|
||||
conversation_result.response, self._device_id
|
||||
conversation_result.response,
|
||||
self._satellite_id,
|
||||
self._device_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1337,39 +1339,62 @@ class PipelineRun:
|
||||
return (speech, all_targets_in_satellite_area)
|
||||
|
||||
def _get_all_targets_in_satellite_area(
|
||||
self, intent_response: intent.IntentResponse, device_id: str | None
|
||||
self,
|
||||
intent_response: intent.IntentResponse,
|
||||
satellite_id: str | None,
|
||||
device_id: str | None,
|
||||
) -> bool:
|
||||
"""Return true if all targeted entities were in the same area as the device."""
|
||||
if (
|
||||
(intent_response.response_type != intent.IntentResponseType.ACTION_DONE)
|
||||
or (not intent_response.matched_states)
|
||||
or (not device_id)
|
||||
):
|
||||
return False
|
||||
|
||||
device_registry = dr.async_get(self.hass)
|
||||
|
||||
if (not (device := device_registry.async_get(device_id))) or (
|
||||
not device.area_id
|
||||
intent_response.response_type != intent.IntentResponseType.ACTION_DONE
|
||||
or not intent_response.matched_states
|
||||
):
|
||||
return False
|
||||
|
||||
entity_registry = er.async_get(self.hass)
|
||||
for state in intent_response.matched_states:
|
||||
entity = entity_registry.async_get(state.entity_id)
|
||||
if not entity:
|
||||
device_registry = dr.async_get(self.hass)
|
||||
|
||||
area_id: str | None = None
|
||||
|
||||
if (
|
||||
satellite_id is not None
|
||||
and (target_entity_entry := entity_registry.async_get(satellite_id))
|
||||
is not None
|
||||
):
|
||||
area_id = target_entity_entry.area_id
|
||||
device_id = target_entity_entry.device_id
|
||||
|
||||
if area_id is None:
|
||||
if device_id is None:
|
||||
return False
|
||||
|
||||
if (entity_area_id := entity.area_id) is None:
|
||||
if (entity.device_id is None) or (
|
||||
(entity_device := device_registry.async_get(entity.device_id))
|
||||
is None
|
||||
):
|
||||
device_entry = device_registry.async_get(device_id)
|
||||
if device_entry is None:
|
||||
return False
|
||||
|
||||
area_id = device_entry.area_id
|
||||
if area_id is None:
|
||||
return False
|
||||
|
||||
for state in intent_response.matched_states:
|
||||
target_entity_entry = entity_registry.async_get(state.entity_id)
|
||||
if target_entity_entry is None:
|
||||
return False
|
||||
|
||||
target_area_id = target_entity_entry.area_id
|
||||
if target_area_id is None:
|
||||
if target_entity_entry.device_id is None:
|
||||
return False
|
||||
|
||||
entity_area_id = entity_device.area_id
|
||||
target_device_entry = device_registry.async_get(
|
||||
target_entity_entry.device_id
|
||||
)
|
||||
if target_device_entry is None:
|
||||
return False
|
||||
|
||||
if entity_area_id != device.area_id:
|
||||
target_area_id = target_device_entry.area_id
|
||||
|
||||
if target_area_id != area_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -1797,6 +1797,7 @@ async def test_chat_log_tts_streaming(
|
||||
assert process_events(events) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("use_satellite_entity"), [True, False])
|
||||
async def test_acknowledge(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
@@ -1805,6 +1806,7 @@ async def test_acknowledge(
|
||||
entity_registry: er.EntityRegistry,
|
||||
area_registry: ar.AreaRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
use_satellite_entity: bool,
|
||||
) -> None:
|
||||
"""Test that acknowledge sound is played when targets are in the same area."""
|
||||
area_1 = area_registry.async_get_or_create("area_1")
|
||||
@@ -1819,12 +1821,16 @@ async def test_acknowledge(
|
||||
|
||||
entry = MockConfigEntry()
|
||||
entry.add_to_hass(hass)
|
||||
satellite = device_registry.async_get_or_create(
|
||||
|
||||
satellite = entity_registry.async_get_or_create("assist_satellite", "test", "1234")
|
||||
entity_registry.async_update_entity(satellite.entity_id, area_id=area_1.id)
|
||||
|
||||
satellite_device = device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections=set(),
|
||||
identifiers={("demo", "id-1234")},
|
||||
)
|
||||
device_registry.async_update_device(satellite.id, area_id=area_1.id)
|
||||
device_registry.async_update_device(satellite_device.id, area_id=area_1.id)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
turn_on = async_mock_service(hass, "light", "turn_on")
|
||||
@@ -1837,7 +1843,8 @@ async def test_acknowledge(
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input=text,
|
||||
session=mock_chat_session,
|
||||
device_id=satellite.id,
|
||||
satellite_id=satellite.entity_id if use_satellite_entity else None,
|
||||
device_id=satellite_device.id if not use_satellite_entity else None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@@ -1889,7 +1896,8 @@ async def test_acknowledge(
|
||||
)
|
||||
|
||||
# 3. Remove satellite device area
|
||||
device_registry.async_update_device(satellite.id, area_id=None)
|
||||
entity_registry.async_update_entity(satellite.entity_id, area_id=None)
|
||||
device_registry.async_update_device(satellite_device.id, area_id=None)
|
||||
|
||||
_reset()
|
||||
await _run("turn on light 1")
|
||||
@@ -1900,7 +1908,8 @@ async def test_acknowledge(
|
||||
assert len(turn_on) == 1
|
||||
|
||||
# Restore
|
||||
device_registry.async_update_device(satellite.id, area_id=area_1.id)
|
||||
entity_registry.async_update_entity(satellite.entity_id, area_id=area_1.id)
|
||||
device_registry.async_update_device(satellite_device.id, area_id=area_1.id)
|
||||
|
||||
# 4. Check device area instead of entity area
|
||||
light_device = device_registry.async_get_or_create(
|
||||
|
||||
Reference in New Issue
Block a user