Prevent reloading the ZHA integration while adapter firmware is being updated (#152626)

This commit is contained in:
puddly
2025-10-09 15:00:02 -04:00
committed by Franck Nijhof
parent 5abaabc9da
commit 1d407d1326
9 changed files with 301 additions and 47 deletions
@@ -1,15 +1,20 @@
"""Home Assistant Hardware integration helpers."""
from __future__ import annotations
from collections import defaultdict
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
import logging
from typing import Protocol
from typing import TYPE_CHECKING, Protocol
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback as hass_callback
from . import DATA_COMPONENT
from .util import FirmwareInfo
if TYPE_CHECKING:
from .util import FirmwareInfo
_LOGGER = logging.getLogger(__name__)
@@ -51,6 +56,7 @@ class HardwareInfoDispatcher:
self._notification_callbacks: defaultdict[
str, set[Callable[[FirmwareInfo], None]]
] = defaultdict(set)
self._active_firmware_updates: dict[str, str] = {}
def register_firmware_info_provider(
self, domain: str, platform: HardwareFirmwareInfoModule
@@ -118,6 +124,36 @@ class HardwareInfoDispatcher:
if fw_info is not None:
yield fw_info
def register_firmware_update_in_progress(
self, device: str, source_domain: str
) -> None:
"""Register that a firmware update is in progress for a device."""
if device in self._active_firmware_updates:
current_domain = self._active_firmware_updates[device]
raise ValueError(
f"Firmware update already in progress for {device} by {current_domain}"
)
self._active_firmware_updates[device] = source_domain
def unregister_firmware_update_in_progress(
self, device: str, source_domain: str
) -> None:
"""Unregister a firmware update for a device."""
if device not in self._active_firmware_updates:
raise ValueError(f"No firmware update in progress for {device}")
if self._active_firmware_updates[device] != source_domain:
current_domain = self._active_firmware_updates[device]
raise ValueError(
f"Firmware update for {device} is owned by {current_domain}, not {source_domain}"
)
del self._active_firmware_updates[device]
def is_firmware_update_in_progress(self, device: str) -> bool:
"""Check if a firmware update is in progress for a device."""
return device in self._active_firmware_updates
@hass_callback
def async_register_firmware_info_provider(
@@ -141,3 +177,42 @@ def async_notify_firmware_info(
) -> Awaitable[None]:
"""Notify the dispatcher of new firmware information."""
return hass.data[DATA_COMPONENT].notify_firmware_info(domain, firmware_info)
@hass_callback
def async_register_firmware_update_in_progress(
hass: HomeAssistant, device: str, source_domain: str
) -> None:
"""Register that a firmware update is in progress for a device."""
return hass.data[DATA_COMPONENT].register_firmware_update_in_progress(
device, source_domain
)
@hass_callback
def async_unregister_firmware_update_in_progress(
hass: HomeAssistant, device: str, source_domain: str
) -> None:
"""Unregister a firmware update for a device."""
return hass.data[DATA_COMPONENT].unregister_firmware_update_in_progress(
device, source_domain
)
@hass_callback
def async_is_firmware_update_in_progress(hass: HomeAssistant, device: str) -> bool:
"""Check if a firmware update is in progress for a device."""
return hass.data[DATA_COMPONENT].is_firmware_update_in_progress(device)
@asynccontextmanager
async def async_firmware_update_context(
hass: HomeAssistant, device: str, source_domain: str
) -> AsyncIterator[None]:
"""Register a device as having its firmware being actively updated."""
async_register_firmware_update_in_progress(hass, device, source_domain)
try:
yield
finally:
async_unregister_firmware_update_in_progress(hass, device, source_domain)
@@ -275,6 +275,7 @@ class BaseFirmwareUpdateEntity(
expected_installed_firmware_type=self.entity_description.expected_firmware_type,
bootloader_reset_methods=self.bootloader_reset_methods,
progress_callback=self._update_progress,
domain=self._config_entry.domain,
)
finally:
self._attr_in_progress = False
@@ -26,6 +26,7 @@ from homeassistant.helpers.singleton import singleton
from . import DATA_COMPONENT
from .const import (
DOMAIN,
OTBR_ADDON_MANAGER_DATA,
OTBR_ADDON_NAME,
OTBR_ADDON_SLUG,
@@ -33,6 +34,7 @@ from .const import (
ZIGBEE_FLASHER_ADDON_NAME,
ZIGBEE_FLASHER_ADDON_SLUG,
)
from .helpers import async_firmware_update_context
from .silabs_multiprotocol_addon import (
WaitingAddonManager,
get_multiprotocol_addon_manager,
@@ -359,45 +361,50 @@ async def async_flash_silabs_firmware(
expected_installed_firmware_type: ApplicationType,
bootloader_reset_methods: Sequence[ResetTarget] = (),
progress_callback: Callable[[int, int], None] | None = None,
*,
domain: str = DOMAIN,
) -> FirmwareInfo:
"""Flash firmware to the SiLabs device."""
firmware_info = await guess_firmware_info(hass, device)
_LOGGER.debug("Identified firmware info: %s", firmware_info)
async with async_firmware_update_context(hass, device, domain):
firmware_info = await guess_firmware_info(hass, device)
_LOGGER.debug("Identified firmware info: %s", firmware_info)
fw_image = await hass.async_add_executor_job(parse_firmware_image, fw_data)
fw_image = await hass.async_add_executor_job(parse_firmware_image, fw_data)
flasher = Flasher(
device=device,
probe_methods=(
ApplicationType.GECKO_BOOTLOADER.as_flasher_application_type(),
ApplicationType.EZSP.as_flasher_application_type(),
ApplicationType.SPINEL.as_flasher_application_type(),
ApplicationType.CPC.as_flasher_application_type(),
),
bootloader_reset=tuple(
m.as_flasher_reset_target() for m in bootloader_reset_methods
),
)
async with AsyncExitStack() as stack:
for owner in firmware_info.owners:
await stack.enter_async_context(owner.temporarily_stop(hass))
try:
# Enter the bootloader with indeterminate progress
await flasher.enter_bootloader()
# Flash the firmware, with progress
await flasher.flash_firmware(fw_image, progress_callback=progress_callback)
except Exception as err:
raise HomeAssistantError("Failed to flash firmware") from err
probed_firmware_info = await probe_silabs_firmware_info(
device,
probe_methods=(expected_installed_firmware_type,),
flasher = Flasher(
device=device,
probe_methods=(
ApplicationType.GECKO_BOOTLOADER.as_flasher_application_type(),
ApplicationType.EZSP.as_flasher_application_type(),
ApplicationType.SPINEL.as_flasher_application_type(),
ApplicationType.CPC.as_flasher_application_type(),
),
bootloader_reset=tuple(
m.as_flasher_reset_target() for m in bootloader_reset_methods
),
)
if probed_firmware_info is None:
raise HomeAssistantError("Failed to probe the firmware after flashing")
async with AsyncExitStack() as stack:
for owner in firmware_info.owners:
await stack.enter_async_context(owner.temporarily_stop(hass))
return probed_firmware_info
try:
# Enter the bootloader with indeterminate progress
await flasher.enter_bootloader()
# Flash the firmware, with progress
await flasher.flash_firmware(
fw_image, progress_callback=progress_callback
)
except Exception as err:
raise HomeAssistantError("Failed to flash firmware") from err
probed_firmware_info = await probe_silabs_firmware_info(
device,
probe_methods=(expected_installed_firmware_type,),
)
if probed_firmware_info is None:
raise HomeAssistantError("Failed to probe the firmware after flashing")
return probed_firmware_info
+14 -1
View File
@@ -13,6 +13,7 @@ from zigpy.config import CONF_DATABASE, CONF_DEVICE, CONF_DEVICE_PATH
from zigpy.exceptions import NetworkSettingsInconsistent, TransientConnectionError
from homeassistant.components.homeassistant_hardware.helpers import (
async_is_firmware_update_in_progress,
async_notify_firmware_info,
async_register_firmware_info_provider,
)
@@ -119,6 +120,14 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
def _raise_if_port_in_use(hass: HomeAssistant, device_path: str) -> None:
"""Ensure that the specified serial port is not in use by a firmware update."""
if async_is_firmware_update_in_progress(hass, device_path):
raise ConfigEntryNotReady(
f"Firmware update in progress for device {device_path}"
)
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Set up ZHA.
@@ -152,6 +161,10 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
_LOGGER.debug("Trigger cache: %s", zha_lib_data.device_trigger_cache)
# Check if firmware update is in progress for this device
device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
_raise_if_port_in_use(hass, device_path)
try:
await zha_gateway.async_initialize()
except NetworkSettingsInconsistent as exc:
@@ -168,7 +181,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
raise ConfigEntryNotReady from exc
except Exception as exc:
_LOGGER.debug("Failed to set up ZHA", exc_info=exc)
device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
_raise_if_port_in_use(hass, device_path)
if (
not device_path.startswith("socket://")
@@ -22,6 +22,9 @@ from homeassistant.components.homeassistant_hardware.firmware_config_flow import
BaseFirmwareConfigFlow,
BaseFirmwareOptionsFlow,
)
from homeassistant.components.homeassistant_hardware.helpers import (
async_firmware_update_context,
)
from homeassistant.components.homeassistant_hardware.util import (
ApplicationType,
FirmwareInfo,
@@ -302,18 +305,21 @@ def mock_firmware_info(
expected_installed_firmware_type: ApplicationType,
bootloader_reset_methods: Sequence[ResetTarget] = (),
progress_callback: Callable[[int, int], None] | None = None,
*,
domain: str = "homeassistant_hardware",
) -> FirmwareInfo:
await asyncio.sleep(0)
progress_callback(0, 100)
await asyncio.sleep(0)
progress_callback(50, 100)
await asyncio.sleep(0)
progress_callback(100, 100)
async with async_firmware_update_context(hass, device, domain):
await asyncio.sleep(0)
progress_callback(0, 100)
await asyncio.sleep(0)
progress_callback(50, 100)
await asyncio.sleep(0)
progress_callback(100, 100)
if flashed_firmware_info is None:
raise HomeAssistantError("Failed to probe the firmware after flashing")
if flashed_firmware_info is None:
raise HomeAssistantError("Failed to probe the firmware after flashing")
return flashed_firmware_info
return flashed_firmware_info
with (
patch(
@@ -7,9 +7,13 @@ import pytest
from homeassistant.components.homeassistant_hardware.const import DATA_COMPONENT
from homeassistant.components.homeassistant_hardware.helpers import (
async_firmware_update_context,
async_is_firmware_update_in_progress,
async_notify_firmware_info,
async_register_firmware_info_callback,
async_register_firmware_info_provider,
async_register_firmware_update_in_progress,
async_unregister_firmware_update_in_progress,
)
from homeassistant.components.homeassistant_hardware.util import (
ApplicationType,
@@ -183,3 +187,73 @@ async def test_dispatcher_callback_error_handling(
assert callback1.mock_calls == [call(FIRMWARE_INFO_EZSP)]
assert callback2.mock_calls == [call(FIRMWARE_INFO_EZSP)]
async def test_firmware_update_tracking(hass: HomeAssistant) -> None:
"""Test firmware update tracking API."""
await async_setup_component(hass, "homeassistant_hardware", {})
device_path = "/dev/ttyUSB0"
assert not async_is_firmware_update_in_progress(hass, device_path)
# Register an update in progress
async_register_firmware_update_in_progress(hass, device_path, "zha")
assert async_is_firmware_update_in_progress(hass, device_path)
with pytest.raises(ValueError, match="Firmware update already in progress"):
async_register_firmware_update_in_progress(hass, device_path, "skyconnect")
assert async_is_firmware_update_in_progress(hass, device_path)
# Unregister the update with correct domain
async_unregister_firmware_update_in_progress(hass, device_path, "zha")
assert not async_is_firmware_update_in_progress(hass, device_path)
# Test unregistering with wrong domain should raise an error
async_register_firmware_update_in_progress(hass, device_path, "zha")
with pytest.raises(ValueError, match="is owned by zha, not skyconnect"):
async_unregister_firmware_update_in_progress(hass, device_path, "skyconnect")
# Still registered to zha
assert async_is_firmware_update_in_progress(hass, device_path)
async_unregister_firmware_update_in_progress(hass, device_path, "zha")
assert not async_is_firmware_update_in_progress(hass, device_path)
async def test_firmware_update_context_manager(hass: HomeAssistant) -> None:
"""Test firmware update progress context manager."""
await async_setup_component(hass, "homeassistant_hardware", {})
device_path = "/dev/ttyUSB0"
# Initially no updates in progress
assert not async_is_firmware_update_in_progress(hass, device_path)
# Test successful completion
async with async_firmware_update_context(hass, device_path, "zha"):
assert async_is_firmware_update_in_progress(hass, device_path)
# Should be cleaned up after context
assert not async_is_firmware_update_in_progress(hass, device_path)
# Test exception handling
with pytest.raises(ValueError, match="test error"): # noqa: PT012
async with async_firmware_update_context(hass, device_path, "zha"):
assert async_is_firmware_update_in_progress(hass, device_path)
raise ValueError("test error")
# Should still be cleaned up after exception
assert not async_is_firmware_update_in_progress(hass, device_path)
# Test concurrent context manager attempts should fail
async with async_firmware_update_context(hass, device_path, "zha"):
assert async_is_firmware_update_in_progress(hass, device_path)
# Second context manager should fail to register
with pytest.raises(ValueError, match="Firmware update already in progress"):
async with async_firmware_update_context(hass, device_path, "skyconnect"):
pytest.fail("We should not enter this context manager")
# Should be cleaned up after first context
assert not async_is_firmware_update_in_progress(hass, device_path)
@@ -364,6 +364,8 @@ async def test_update_entity_installation(
expected_installed_firmware_type: ApplicationType,
bootloader_reset_methods: Sequence[ResetTarget] = (),
progress_callback: Callable[[int, int], None] | None = None,
*,
domain: str = "homeassistant_hardware",
) -> FirmwareInfo:
await asyncio.sleep(0)
progress_callback(0, 100)
@@ -537,6 +537,8 @@ async def test_probe_silabs_firmware_type(
async def test_async_flash_silabs_firmware(hass: HomeAssistant) -> None:
"""Test async_flash_silabs_firmware."""
await async_setup_component(hass, "homeassistant_hardware", {})
owner1 = create_mock_owner()
owner2 = create_mock_owner()
@@ -625,6 +627,8 @@ async def test_async_flash_silabs_firmware(hass: HomeAssistant) -> None:
async def test_async_flash_silabs_firmware_flash_failure(hass: HomeAssistant) -> None:
"""Test async_flash_silabs_firmware flash failure."""
await async_setup_component(hass, "homeassistant_hardware", {})
owner1 = create_mock_owner()
owner2 = create_mock_owner()
@@ -679,6 +683,8 @@ async def test_async_flash_silabs_firmware_flash_failure(hass: HomeAssistant) ->
async def test_async_flash_silabs_firmware_probe_failure(hass: HomeAssistant) -> None:
"""Test async_flash_silabs_firmware probe failure."""
await async_setup_component(hass, "homeassistant_hardware", {})
owner1 = create_mock_owner()
owner2 = create_mock_owner()
+70
View File
@@ -12,6 +12,10 @@ from zigpy.config import CONF_DEVICE, CONF_DEVICE_PATH
from zigpy.device import Device
from zigpy.exceptions import TransientConnectionError
from homeassistant.components.homeassistant_hardware.helpers import (
async_is_firmware_update_in_progress,
async_register_firmware_update_in_progress,
)
from homeassistant.components.zha.const import (
CONF_BAUDRATE,
CONF_FLOW_CONTROL,
@@ -20,6 +24,7 @@ from homeassistant.components.zha.const import (
DOMAIN,
)
from homeassistant.components.zha.helpers import get_zha_data, get_zha_gateway
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
MAJOR_VERSION,
@@ -311,3 +316,68 @@ async def test_timezone_update(
assert hass.config.time_zone == "America/New_York"
assert gateway.config.local_timezone == zoneinfo.ZoneInfo("America/New_York")
async def test_setup_no_firmware_update_in_progress(
hass: HomeAssistant,
config_entry: MockConfigEntry,
mock_zigpy_connect: ControllerApplication,
) -> None:
"""Test that ZHA setup proceeds normally when no firmware update is in progress."""
await async_setup_component(hass, "homeassistant_hardware", {})
config_entry.add_to_hass(hass)
device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
assert not async_is_firmware_update_in_progress(hass, device_path)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
async def test_setup_firmware_update_in_progress(
hass: HomeAssistant,
config_entry: MockConfigEntry,
) -> None:
"""Test that ZHA setup is blocked when firmware update is in progress."""
await async_setup_component(hass, "homeassistant_hardware", {})
config_entry.add_to_hass(hass)
device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
async_register_firmware_update_in_progress(hass, device_path, "skyconnect")
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_setup_firmware_update_in_progress_prevents_silabs_warning(
hass: HomeAssistant,
config_entry: MockConfigEntry,
mock_zigpy_connect: ControllerApplication,
) -> None:
"""Test firmware update in progress prevents silabs firmware warning on setup failure."""
await async_setup_component(hass, "homeassistant_hardware", {})
config_entry.add_to_hass(hass)
device_path = config_entry.data[CONF_DEVICE][CONF_DEVICE_PATH]
async_register_firmware_update_in_progress(hass, device_path, "skyconnect")
# Make ZHA setup fail
with (
patch.object(
mock_zigpy_connect,
"startup",
side_effect=Exception("Setup failed"),
),
patch(
"homeassistant.components.zha.repairs.wrong_silabs_firmware.warn_on_wrong_silabs_firmware"
) as mock_check_firmware,
):
await hass.config_entries.async_setup(config_entry.entry_id)
# ZHA will try to reload again
assert config_entry.state is ConfigEntryState.SETUP_RETRY
# But it did not try to check if the wrong firmware is installed
assert mock_check_firmware.call_count == 0