mirror of
https://github.com/home-assistant/core.git
synced 2026-04-20 16:39:02 +02:00
Abort USB discovery flows on device unplug (#156303)
This commit is contained in:
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Sequence
|
||||
import dataclasses
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
import logging
|
||||
@@ -45,7 +44,7 @@ from .utils import (
|
||||
usb_device_from_path, # noqa: F401
|
||||
usb_device_from_port, # noqa: F401
|
||||
usb_device_matches_matcher,
|
||||
usb_service_info_from_device, # noqa: F401
|
||||
usb_service_info_from_device,
|
||||
usb_unique_id_from_service_info, # noqa: F401
|
||||
)
|
||||
|
||||
@@ -59,7 +58,6 @@ ADD_REMOVE_SCAN_COOLDOWN = 5 # 5 second cooldown to give devices a chance to re
|
||||
|
||||
__all__ = [
|
||||
"USBCallbackMatcher",
|
||||
"async_is_plugged_in",
|
||||
"async_register_port_event_callback",
|
||||
"async_register_scan_request_callback",
|
||||
]
|
||||
@@ -101,51 +99,6 @@ def async_register_port_event_callback(
|
||||
return discovery.async_register_port_event_callback(callback)
|
||||
|
||||
|
||||
@hass_callback
|
||||
def async_is_plugged_in(hass: HomeAssistant, matcher: USBCallbackMatcher) -> bool:
|
||||
"""Return True is a USB device is present."""
|
||||
|
||||
vid = matcher.get("vid", "")
|
||||
pid = matcher.get("pid", "")
|
||||
serial_number = matcher.get("serial_number", "")
|
||||
manufacturer = matcher.get("manufacturer", "")
|
||||
description = matcher.get("description", "")
|
||||
|
||||
if (
|
||||
vid != vid.upper()
|
||||
or pid != pid.upper()
|
||||
or serial_number != serial_number.lower()
|
||||
or manufacturer != manufacturer.lower()
|
||||
or description != description.lower()
|
||||
):
|
||||
raise ValueError(
|
||||
f"vid and pid must be uppercase, the rest lowercase in matcher {matcher!r}"
|
||||
)
|
||||
|
||||
usb_discovery: USBDiscovery = hass.data[DOMAIN]
|
||||
return any(
|
||||
usb_device_matches_matcher(
|
||||
USBDevice(
|
||||
device=device,
|
||||
vid=vid,
|
||||
pid=pid,
|
||||
serial_number=serial_number,
|
||||
manufacturer=manufacturer,
|
||||
description=description,
|
||||
),
|
||||
matcher,
|
||||
)
|
||||
for (
|
||||
device,
|
||||
vid,
|
||||
pid,
|
||||
serial_number,
|
||||
manufacturer,
|
||||
description,
|
||||
) in usb_discovery.seen
|
||||
)
|
||||
|
||||
|
||||
@hass_callback
|
||||
def async_get_usb_matchers_for_device(
|
||||
hass: HomeAssistant, device: USBDevice
|
||||
@@ -244,7 +197,6 @@ class USBDiscovery:
|
||||
"""Init USB Discovery."""
|
||||
self.hass = hass
|
||||
self.usb = usb
|
||||
self.seen: set[tuple[str, ...]] = set()
|
||||
self.observer_active = False
|
||||
self._request_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None
|
||||
self._add_remove_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None
|
||||
@@ -393,30 +345,13 @@ class USBDiscovery:
|
||||
async def _async_process_discovered_usb_device(self, device: USBDevice) -> None:
|
||||
"""Process a USB discovery."""
|
||||
_LOGGER.debug("Discovered USB Device: %s", device)
|
||||
device_tuple = dataclasses.astuple(device)
|
||||
if device_tuple in self.seen:
|
||||
return
|
||||
self.seen.add(device_tuple)
|
||||
|
||||
matched = self.async_get_usb_matchers_for_device(device)
|
||||
if not matched:
|
||||
return
|
||||
|
||||
service_info: _UsbServiceInfo | None = None
|
||||
service_info = usb_service_info_from_device(device)
|
||||
|
||||
for matcher in matched:
|
||||
if service_info is None:
|
||||
service_info = _UsbServiceInfo(
|
||||
device=await self.hass.async_add_executor_job(
|
||||
get_serial_by_id, device.device
|
||||
),
|
||||
vid=device.vid,
|
||||
pid=device.pid,
|
||||
serial_number=device.serial_number,
|
||||
manufacturer=device.manufacturer,
|
||||
description=device.description,
|
||||
)
|
||||
|
||||
discovery_flow.async_create_flow(
|
||||
self.hass,
|
||||
matcher["domain"],
|
||||
@@ -424,6 +359,26 @@ class USBDiscovery:
|
||||
service_info,
|
||||
)
|
||||
|
||||
async def _async_process_removed_usb_device(self, device: USBDevice) -> None:
|
||||
"""Process a USB removal."""
|
||||
_LOGGER.debug("Removed USB Device: %s", device)
|
||||
matched = self.async_get_usb_matchers_for_device(device)
|
||||
if not matched:
|
||||
return
|
||||
|
||||
service_info = usb_service_info_from_device(device)
|
||||
|
||||
for matcher in matched:
|
||||
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
|
||||
_UsbServiceInfo,
|
||||
lambda flow_service_info: flow_service_info == service_info,
|
||||
):
|
||||
if matcher["domain"] != flow["handler"]:
|
||||
continue
|
||||
|
||||
_LOGGER.debug("Aborting existing flow %s", flow["flow_id"])
|
||||
self.hass.config_entries.flow.async_abort(flow["flow_id"])
|
||||
|
||||
async def _async_process_ports(self, usb_devices: Sequence[USBDevice]) -> None:
|
||||
"""Process each discovered port."""
|
||||
_LOGGER.debug("USB devices: %r", usb_devices)
|
||||
@@ -464,7 +419,10 @@ class USBDiscovery:
|
||||
except Exception:
|
||||
_LOGGER.exception("Error in USB port event callback")
|
||||
|
||||
for usb_device in filtered_usb_devices:
|
||||
for usb_device in removed_devices:
|
||||
await self._async_process_removed_usb_device(usb_device)
|
||||
|
||||
for usb_device in added_devices:
|
||||
await self._async_process_discovered_usb_device(usb_device)
|
||||
|
||||
@hass_callback
|
||||
|
||||
@@ -9,6 +9,7 @@ from unittest.mock import MagicMock, Mock, call, patch, sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import usb
|
||||
from homeassistant.components.usb.models import USBDevice
|
||||
from homeassistant.components.usb.utils import scan_serial_ports, usb_device_from_path
|
||||
@@ -23,7 +24,14 @@ from . import (
|
||||
patch_scanned_serial_ports,
|
||||
)
|
||||
|
||||
from tests.common import async_fire_time_changed, import_and_test_deprecated_constant
|
||||
from tests.common import (
|
||||
MockModule,
|
||||
async_fire_time_changed,
|
||||
import_and_test_deprecated_constant,
|
||||
mock_config_flow,
|
||||
mock_integration,
|
||||
mock_platform,
|
||||
)
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
conbee_device = USBDevice(
|
||||
@@ -880,84 +888,6 @@ def test_human_readable_device_name() -> None:
|
||||
assert "8A2A" in name
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("force_usb_polling_watcher")
|
||||
async def test_async_is_plugged_in(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
|
||||
) -> None:
|
||||
"""Test async_is_plugged_in."""
|
||||
new_usb = [{"domain": "test1", "vid": "3039", "pid": "3039"}]
|
||||
|
||||
mock_ports = [
|
||||
USBDevice(
|
||||
device=slae_sh_device.device,
|
||||
vid="3039",
|
||||
pid="3039",
|
||||
serial_number=slae_sh_device.serial_number,
|
||||
manufacturer=slae_sh_device.manufacturer,
|
||||
description=slae_sh_device.description,
|
||||
)
|
||||
]
|
||||
|
||||
matcher = {
|
||||
"vid": "3039",
|
||||
"pid": "3039",
|
||||
}
|
||||
|
||||
with (
|
||||
patch("homeassistant.components.usb.async_get_usb", return_value=new_usb),
|
||||
patch_scanned_serial_ports(return_value=[]),
|
||||
patch.object(hass.config_entries.flow, "async_init"),
|
||||
):
|
||||
assert await async_setup_component(hass, "usb", {"usb": {}})
|
||||
await hass.async_block_till_done()
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
await hass.async_block_till_done()
|
||||
assert not usb.async_is_plugged_in(hass, matcher)
|
||||
|
||||
with (
|
||||
patch_scanned_serial_ports(return_value=mock_ports),
|
||||
patch.object(hass.config_entries.flow, "async_init"),
|
||||
):
|
||||
ws_client = await hass_ws_client(hass)
|
||||
await ws_client.send_json({"id": 1, "type": "usb/scan"})
|
||||
response = await ws_client.receive_json()
|
||||
assert response["success"]
|
||||
await hass.async_block_till_done()
|
||||
assert usb.async_is_plugged_in(hass, matcher)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("force_usb_polling_watcher")
|
||||
@pytest.mark.parametrize(
|
||||
"matcher",
|
||||
[
|
||||
{"vid": "abcd"},
|
||||
{"pid": "123a"},
|
||||
{"serial_number": "1234ABCD"},
|
||||
{"manufacturer": "Some Manufacturer"},
|
||||
{"description": "A description"},
|
||||
],
|
||||
)
|
||||
async def test_async_is_plugged_in_case_enforcement(
|
||||
hass: HomeAssistant, matcher
|
||||
) -> None:
|
||||
"""Test `async_is_plugged_in` throws an error when incorrect cases are used."""
|
||||
|
||||
new_usb = [{"domain": "test1", "vid": "ABCD"}]
|
||||
|
||||
with (
|
||||
patch("homeassistant.components.usb.async_get_usb", return_value=new_usb),
|
||||
patch_scanned_serial_ports(return_value=[]),
|
||||
patch.object(hass.config_entries.flow, "async_init"),
|
||||
):
|
||||
assert await async_setup_component(hass, "usb", {"usb": {}})
|
||||
await hass.async_block_till_done()
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
usb.async_is_plugged_in(hass, matcher)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("force_usb_polling_watcher")
|
||||
async def test_web_socket_triggers_discovery_request_callbacks(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
|
||||
@@ -1055,48 +985,6 @@ async def test_cancel_initial_scan_callback(
|
||||
assert len(mock_callback.mock_calls) == 0
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("force_usb_polling_watcher")
|
||||
async def test_resolve_serial_by_id(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
|
||||
) -> None:
|
||||
"""Test the discovery data resolves to serial/by-id."""
|
||||
new_usb = [{"domain": "test1", "vid": "3039", "pid": "3039"}]
|
||||
|
||||
mock_ports = [
|
||||
USBDevice(
|
||||
device=slae_sh_device.device,
|
||||
vid="3039",
|
||||
pid="3039",
|
||||
serial_number=slae_sh_device.serial_number,
|
||||
manufacturer=slae_sh_device.manufacturer,
|
||||
description=slae_sh_device.description,
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
patch("homeassistant.components.usb.async_get_usb", return_value=new_usb),
|
||||
patch_scanned_serial_ports(return_value=mock_ports),
|
||||
patch(
|
||||
"homeassistant.components.usb.get_serial_by_id",
|
||||
return_value="/dev/serial/by-id/bla",
|
||||
),
|
||||
patch.object(hass.config_entries.flow, "async_init") as mock_config_flow,
|
||||
):
|
||||
assert await async_setup_component(hass, "usb", {"usb": {}})
|
||||
await hass.async_block_till_done()
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
await hass.async_block_till_done()
|
||||
ws_client = await hass_ws_client(hass)
|
||||
await ws_client.send_json({"id": 1, "type": "usb/scan"})
|
||||
response = await ws_client.receive_json()
|
||||
assert response["success"]
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(mock_config_flow.mock_calls) == 1
|
||||
assert mock_config_flow.mock_calls[0][1][0] == "test1"
|
||||
assert mock_config_flow.mock_calls[0][2]["data"].device == "/dev/serial/by-id/bla"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("force_usb_polling_watcher")
|
||||
@pytest.mark.parametrize(
|
||||
"ports",
|
||||
@@ -1535,3 +1423,172 @@ def test_usb_device_from_path_returns_none_when_not_found() -> None:
|
||||
result = usb_device_from_path("/dev/ttyUSB99")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("force_usb_polling_watcher")
|
||||
@patch("homeassistant.components.usb.REQUEST_SCAN_COOLDOWN", 0)
|
||||
async def test_removal_aborts_discovery_flows(
|
||||
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
|
||||
) -> None:
|
||||
"""Test USB device removal aborts the correct discovery flows."""
|
||||
# Used by test1
|
||||
device1 = USBDevice(
|
||||
device="/dev/serial/by-id/unique-device-1",
|
||||
vid="1234",
|
||||
pid="5678",
|
||||
serial_number="ABC123",
|
||||
manufacturer="Test Manufacturer 1",
|
||||
description="Test Device 1 for domain test1",
|
||||
)
|
||||
|
||||
# Used by test1
|
||||
device2 = USBDevice(
|
||||
device="/dev/serial/by-id/unique-device-2",
|
||||
vid="ABCD",
|
||||
pid="EF01",
|
||||
serial_number="XYZ789",
|
||||
manufacturer="Test Manufacturer 2",
|
||||
description="Test Device 2 for domain test1",
|
||||
)
|
||||
|
||||
# Used by test2
|
||||
device3 = USBDevice(
|
||||
device="/dev/serial/by-id/unique-device-3",
|
||||
vid="AAAA",
|
||||
pid="BBBB",
|
||||
serial_number="ABCDEF",
|
||||
manufacturer="Test Manufacturer 3",
|
||||
description="Test Device 3 for domain test2",
|
||||
)
|
||||
|
||||
# Not used by any domain
|
||||
device4 = USBDevice(
|
||||
device="/dev/serial/by-id/unique-device-4",
|
||||
vid="CCCC",
|
||||
pid="DDDD",
|
||||
serial_number="ABCDEF",
|
||||
manufacturer="Test Manufacturer 4",
|
||||
description="Test Device 4",
|
||||
)
|
||||
|
||||
# Used by both test1 and test2
|
||||
device5 = USBDevice(
|
||||
device="/dev/serial/by-id/multi-domain-device",
|
||||
vid="FFFF",
|
||||
pid="EEEE",
|
||||
serial_number="MULTI123",
|
||||
manufacturer="Test Manufacturer 5",
|
||||
description="Device matching multiple domains",
|
||||
)
|
||||
|
||||
class TestFlow(config_entries.ConfigFlow):
|
||||
VERSION = 1
|
||||
|
||||
async def async_step_usb(self, discovery_info):
|
||||
return self.async_show_form(step_id="confirm")
|
||||
|
||||
async def async_step_confirm(self, user_input=None):
|
||||
# There's no way to exit
|
||||
return self.async_show_form(step_id="confirm")
|
||||
|
||||
mock_integration(hass, MockModule("test1"))
|
||||
mock_platform(hass, "test1.config_flow", None)
|
||||
|
||||
mock_integration(hass, MockModule("test2"))
|
||||
mock_platform(hass, "test2.config_flow", None)
|
||||
|
||||
ws_client = await hass_ws_client(hass)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.usb.async_get_usb",
|
||||
return_value=[
|
||||
# Domain `test1` matches devices 1 and 2
|
||||
{"domain": "test1", "vid": "1234", "pid": "5678"},
|
||||
{"domain": "test1", "vid": "ABCD", "pid": "EF01"},
|
||||
# Domain `test2` matches device 3
|
||||
{"domain": "test2", "vid": "AAAA", "pid": "BBBB"},
|
||||
# Both domains match device 5
|
||||
{"domain": "test1", "vid": "FFFF", "pid": "EEEE"},
|
||||
{"domain": "test2", "vid": "FFFF", "pid": "EEEE"},
|
||||
],
|
||||
),
|
||||
# All devices are plugged in initially
|
||||
patch_scanned_serial_ports(
|
||||
return_value=[device1, device2, device3, device4, device5]
|
||||
),
|
||||
mock_config_flow("test1", TestFlow),
|
||||
mock_config_flow("test2", TestFlow),
|
||||
):
|
||||
assert await async_setup_component(hass, "usb", {"usb": {}})
|
||||
await hass.async_block_till_done()
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Discovery will create five flows (device5 is matched by both domains)
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 5
|
||||
|
||||
# Three flows for test1 (1, 2, 5), two for test2 (3, 5)
|
||||
assert sorted([flow["handler"] for flow in flows]) == [
|
||||
"test1",
|
||||
"test1",
|
||||
"test1",
|
||||
"test2",
|
||||
"test2",
|
||||
]
|
||||
|
||||
# Device 5 is removed
|
||||
with patch_scanned_serial_ports(
|
||||
return_value=[device1, device2, device3, device4]
|
||||
):
|
||||
await ws_client.send_json({"id": 1, "type": "usb/scan"})
|
||||
response = await ws_client.receive_json()
|
||||
assert response["success"]
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Both flows for device5 should be aborted (one test1, one test2)
|
||||
remaining_flows = hass.config_entries.flow.async_progress()
|
||||
assert len(remaining_flows) == 3
|
||||
assert sorted([flow["handler"] for flow in remaining_flows]) == [
|
||||
"test1",
|
||||
"test1",
|
||||
"test2",
|
||||
]
|
||||
|
||||
# Device 3 disappears
|
||||
with patch_scanned_serial_ports(return_value=[device1, device2, device4]):
|
||||
await ws_client.send_json({"id": 2, "type": "usb/scan"})
|
||||
response = await ws_client.receive_json()
|
||||
assert response["success"]
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# The corresponding flow is removed
|
||||
remaining_flows = hass.config_entries.flow.async_progress()
|
||||
assert len(remaining_flows) == 2
|
||||
assert sorted([flow["handler"] for flow in remaining_flows]) == [
|
||||
"test1",
|
||||
"test1",
|
||||
]
|
||||
|
||||
# Remove the others
|
||||
with patch_scanned_serial_ports(return_value=[]):
|
||||
await ws_client.send_json({"id": 3, "type": "usb/scan"})
|
||||
response = await ws_client.receive_json()
|
||||
assert response["success"]
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# All the remaining flows should be aborted
|
||||
assert len(hass.config_entries.flow.async_progress()) == 0
|
||||
|
||||
# Plug one back in and the unused device4
|
||||
with patch_scanned_serial_ports(return_value=[device3, device4]):
|
||||
await ws_client.send_json({"id": 4, "type": "usb/scan"})
|
||||
response = await ws_client.receive_json()
|
||||
assert response["success"]
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# A new flow is re-created for the old device
|
||||
final_flows = hass.config_entries.flow.async_progress()
|
||||
assert len(final_flows) == 1
|
||||
assert final_flows[0]["handler"] == "test2"
|
||||
|
||||
Reference in New Issue
Block a user