Abort USB discovery flows on device unplug (#156303)

This commit is contained in:
puddly
2025-11-26 03:00:41 -05:00
committed by GitHub
parent c41493860d
commit 1c0dd02a7c
2 changed files with 204 additions and 189 deletions

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Sequence
import dataclasses
from datetime import datetime, timedelta
from functools import partial
import logging
@@ -45,7 +44,7 @@ from .utils import (
usb_device_from_path, # noqa: F401
usb_device_from_port, # noqa: F401
usb_device_matches_matcher,
usb_service_info_from_device, # noqa: F401
usb_service_info_from_device,
usb_unique_id_from_service_info, # noqa: F401
)
@@ -59,7 +58,6 @@ ADD_REMOVE_SCAN_COOLDOWN = 5 # 5 second cooldown to give devices a chance to re
__all__ = [
"USBCallbackMatcher",
"async_is_plugged_in",
"async_register_port_event_callback",
"async_register_scan_request_callback",
]
@@ -101,51 +99,6 @@ def async_register_port_event_callback(
return discovery.async_register_port_event_callback(callback)
@hass_callback
def async_is_plugged_in(hass: HomeAssistant, matcher: USBCallbackMatcher) -> bool:
"""Return True is a USB device is present."""
vid = matcher.get("vid", "")
pid = matcher.get("pid", "")
serial_number = matcher.get("serial_number", "")
manufacturer = matcher.get("manufacturer", "")
description = matcher.get("description", "")
if (
vid != vid.upper()
or pid != pid.upper()
or serial_number != serial_number.lower()
or manufacturer != manufacturer.lower()
or description != description.lower()
):
raise ValueError(
f"vid and pid must be uppercase, the rest lowercase in matcher {matcher!r}"
)
usb_discovery: USBDiscovery = hass.data[DOMAIN]
return any(
usb_device_matches_matcher(
USBDevice(
device=device,
vid=vid,
pid=pid,
serial_number=serial_number,
manufacturer=manufacturer,
description=description,
),
matcher,
)
for (
device,
vid,
pid,
serial_number,
manufacturer,
description,
) in usb_discovery.seen
)
@hass_callback
def async_get_usb_matchers_for_device(
hass: HomeAssistant, device: USBDevice
@@ -244,7 +197,6 @@ class USBDiscovery:
"""Init USB Discovery."""
self.hass = hass
self.usb = usb
self.seen: set[tuple[str, ...]] = set()
self.observer_active = False
self._request_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None
self._add_remove_debouncer: Debouncer[Coroutine[Any, Any, None]] | None = None
@@ -393,30 +345,13 @@ class USBDiscovery:
async def _async_process_discovered_usb_device(self, device: USBDevice) -> None:
"""Process a USB discovery."""
_LOGGER.debug("Discovered USB Device: %s", device)
device_tuple = dataclasses.astuple(device)
if device_tuple in self.seen:
return
self.seen.add(device_tuple)
matched = self.async_get_usb_matchers_for_device(device)
if not matched:
return
service_info: _UsbServiceInfo | None = None
service_info = usb_service_info_from_device(device)
for matcher in matched:
if service_info is None:
service_info = _UsbServiceInfo(
device=await self.hass.async_add_executor_job(
get_serial_by_id, device.device
),
vid=device.vid,
pid=device.pid,
serial_number=device.serial_number,
manufacturer=device.manufacturer,
description=device.description,
)
discovery_flow.async_create_flow(
self.hass,
matcher["domain"],
@@ -424,6 +359,26 @@ class USBDiscovery:
service_info,
)
async def _async_process_removed_usb_device(self, device: USBDevice) -> None:
"""Process a USB removal."""
_LOGGER.debug("Removed USB Device: %s", device)
matched = self.async_get_usb_matchers_for_device(device)
if not matched:
return
service_info = usb_service_info_from_device(device)
for matcher in matched:
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
_UsbServiceInfo,
lambda flow_service_info: flow_service_info == service_info,
):
if matcher["domain"] != flow["handler"]:
continue
_LOGGER.debug("Aborting existing flow %s", flow["flow_id"])
self.hass.config_entries.flow.async_abort(flow["flow_id"])
async def _async_process_ports(self, usb_devices: Sequence[USBDevice]) -> None:
"""Process each discovered port."""
_LOGGER.debug("USB devices: %r", usb_devices)
@@ -464,7 +419,10 @@ class USBDiscovery:
except Exception:
_LOGGER.exception("Error in USB port event callback")
for usb_device in filtered_usb_devices:
for usb_device in removed_devices:
await self._async_process_removed_usb_device(usb_device)
for usb_device in added_devices:
await self._async_process_discovered_usb_device(usb_device)
@hass_callback

View File

@@ -9,6 +9,7 @@ from unittest.mock import MagicMock, Mock, call, patch, sentinel
import pytest
from homeassistant import config_entries
from homeassistant.components import usb
from homeassistant.components.usb.models import USBDevice
from homeassistant.components.usb.utils import scan_serial_ports, usb_device_from_path
@@ -23,7 +24,14 @@ from . import (
patch_scanned_serial_ports,
)
from tests.common import async_fire_time_changed, import_and_test_deprecated_constant
from tests.common import (
MockModule,
async_fire_time_changed,
import_and_test_deprecated_constant,
mock_config_flow,
mock_integration,
mock_platform,
)
from tests.typing import WebSocketGenerator
conbee_device = USBDevice(
@@ -880,84 +888,6 @@ def test_human_readable_device_name() -> None:
assert "8A2A" in name
@pytest.mark.usefixtures("force_usb_polling_watcher")
async def test_async_is_plugged_in(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test async_is_plugged_in."""
new_usb = [{"domain": "test1", "vid": "3039", "pid": "3039"}]
mock_ports = [
USBDevice(
device=slae_sh_device.device,
vid="3039",
pid="3039",
serial_number=slae_sh_device.serial_number,
manufacturer=slae_sh_device.manufacturer,
description=slae_sh_device.description,
)
]
matcher = {
"vid": "3039",
"pid": "3039",
}
with (
patch("homeassistant.components.usb.async_get_usb", return_value=new_usb),
patch_scanned_serial_ports(return_value=[]),
patch.object(hass.config_entries.flow, "async_init"),
):
assert await async_setup_component(hass, "usb", {"usb": {}})
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
assert not usb.async_is_plugged_in(hass, matcher)
with (
patch_scanned_serial_ports(return_value=mock_ports),
patch.object(hass.config_entries.flow, "async_init"),
):
ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
assert usb.async_is_plugged_in(hass, matcher)
@pytest.mark.usefixtures("force_usb_polling_watcher")
@pytest.mark.parametrize(
"matcher",
[
{"vid": "abcd"},
{"pid": "123a"},
{"serial_number": "1234ABCD"},
{"manufacturer": "Some Manufacturer"},
{"description": "A description"},
],
)
async def test_async_is_plugged_in_case_enforcement(
hass: HomeAssistant, matcher
) -> None:
"""Test `async_is_plugged_in` throws an error when incorrect cases are used."""
new_usb = [{"domain": "test1", "vid": "ABCD"}]
with (
patch("homeassistant.components.usb.async_get_usb", return_value=new_usb),
patch_scanned_serial_ports(return_value=[]),
patch.object(hass.config_entries.flow, "async_init"),
):
assert await async_setup_component(hass, "usb", {"usb": {}})
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
with pytest.raises(ValueError):
usb.async_is_plugged_in(hass, matcher)
@pytest.mark.usefixtures("force_usb_polling_watcher")
async def test_web_socket_triggers_discovery_request_callbacks(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
@@ -1055,48 +985,6 @@ async def test_cancel_initial_scan_callback(
assert len(mock_callback.mock_calls) == 0
@pytest.mark.usefixtures("force_usb_polling_watcher")
async def test_resolve_serial_by_id(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test the discovery data resolves to serial/by-id."""
new_usb = [{"domain": "test1", "vid": "3039", "pid": "3039"}]
mock_ports = [
USBDevice(
device=slae_sh_device.device,
vid="3039",
pid="3039",
serial_number=slae_sh_device.serial_number,
manufacturer=slae_sh_device.manufacturer,
description=slae_sh_device.description,
)
]
with (
patch("homeassistant.components.usb.async_get_usb", return_value=new_usb),
patch_scanned_serial_ports(return_value=mock_ports),
patch(
"homeassistant.components.usb.get_serial_by_id",
return_value="/dev/serial/by-id/bla",
),
patch.object(hass.config_entries.flow, "async_init") as mock_config_flow,
):
assert await async_setup_component(hass, "usb", {"usb": {}})
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
ws_client = await hass_ws_client(hass)
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 1
assert mock_config_flow.mock_calls[0][1][0] == "test1"
assert mock_config_flow.mock_calls[0][2]["data"].device == "/dev/serial/by-id/bla"
@pytest.mark.usefixtures("force_usb_polling_watcher")
@pytest.mark.parametrize(
"ports",
@@ -1535,3 +1423,172 @@ def test_usb_device_from_path_returns_none_when_not_found() -> None:
result = usb_device_from_path("/dev/ttyUSB99")
assert result is None
@pytest.mark.usefixtures("force_usb_polling_watcher")
@patch("homeassistant.components.usb.REQUEST_SCAN_COOLDOWN", 0)
async def test_removal_aborts_discovery_flows(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test USB device removal aborts the correct discovery flows."""
# Used by test1
device1 = USBDevice(
device="/dev/serial/by-id/unique-device-1",
vid="1234",
pid="5678",
serial_number="ABC123",
manufacturer="Test Manufacturer 1",
description="Test Device 1 for domain test1",
)
# Used by test1
device2 = USBDevice(
device="/dev/serial/by-id/unique-device-2",
vid="ABCD",
pid="EF01",
serial_number="XYZ789",
manufacturer="Test Manufacturer 2",
description="Test Device 2 for domain test1",
)
# Used by test2
device3 = USBDevice(
device="/dev/serial/by-id/unique-device-3",
vid="AAAA",
pid="BBBB",
serial_number="ABCDEF",
manufacturer="Test Manufacturer 3",
description="Test Device 3 for domain test2",
)
# Not used by any domain
device4 = USBDevice(
device="/dev/serial/by-id/unique-device-4",
vid="CCCC",
pid="DDDD",
serial_number="ABCDEF",
manufacturer="Test Manufacturer 4",
description="Test Device 4",
)
# Used by both test1 and test2
device5 = USBDevice(
device="/dev/serial/by-id/multi-domain-device",
vid="FFFF",
pid="EEEE",
serial_number="MULTI123",
manufacturer="Test Manufacturer 5",
description="Device matching multiple domains",
)
class TestFlow(config_entries.ConfigFlow):
VERSION = 1
async def async_step_usb(self, discovery_info):
return self.async_show_form(step_id="confirm")
async def async_step_confirm(self, user_input=None):
# There's no way to exit
return self.async_show_form(step_id="confirm")
mock_integration(hass, MockModule("test1"))
mock_platform(hass, "test1.config_flow", None)
mock_integration(hass, MockModule("test2"))
mock_platform(hass, "test2.config_flow", None)
ws_client = await hass_ws_client(hass)
with (
patch(
"homeassistant.components.usb.async_get_usb",
return_value=[
# Domain `test1` matches devices 1 and 2
{"domain": "test1", "vid": "1234", "pid": "5678"},
{"domain": "test1", "vid": "ABCD", "pid": "EF01"},
# Domain `test2` matches device 3
{"domain": "test2", "vid": "AAAA", "pid": "BBBB"},
# Both domains match device 5
{"domain": "test1", "vid": "FFFF", "pid": "EEEE"},
{"domain": "test2", "vid": "FFFF", "pid": "EEEE"},
],
),
# All devices are plugged in initially
patch_scanned_serial_ports(
return_value=[device1, device2, device3, device4, device5]
),
mock_config_flow("test1", TestFlow),
mock_config_flow("test2", TestFlow),
):
assert await async_setup_component(hass, "usb", {"usb": {}})
await hass.async_block_till_done()
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
# Discovery will create five flows (device5 is matched by both domains)
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 5
# Three flows for test1 (1, 2, 5), two for test2 (3, 5)
assert sorted([flow["handler"] for flow in flows]) == [
"test1",
"test1",
"test1",
"test2",
"test2",
]
# Device 5 is removed
with patch_scanned_serial_ports(
return_value=[device1, device2, device3, device4]
):
await ws_client.send_json({"id": 1, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
# Both flows for device5 should be aborted (one test1, one test2)
remaining_flows = hass.config_entries.flow.async_progress()
assert len(remaining_flows) == 3
assert sorted([flow["handler"] for flow in remaining_flows]) == [
"test1",
"test1",
"test2",
]
# Device 3 disappears
with patch_scanned_serial_ports(return_value=[device1, device2, device4]):
await ws_client.send_json({"id": 2, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
# The corresponding flow is removed
remaining_flows = hass.config_entries.flow.async_progress()
assert len(remaining_flows) == 2
assert sorted([flow["handler"] for flow in remaining_flows]) == [
"test1",
"test1",
]
# Remove the others
with patch_scanned_serial_ports(return_value=[]):
await ws_client.send_json({"id": 3, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
# All the remaining flows should be aborted
assert len(hass.config_entries.flow.async_progress()) == 0
# Plug one back in and the unused device4
with patch_scanned_serial_ports(return_value=[device3, device4]):
await ws_client.send_json({"id": 4, "type": "usb/scan"})
response = await ws_client.receive_json()
assert response["success"]
await hass.async_block_till_done()
# A new flow is re-created for the old device
final_flows = hass.config_entries.flow.async_progress()
assert len(final_flows) == 1
assert final_flows[0]["handler"] == "test2"