From 8f12a5f2511ab872880a186f5a4605c8bae80c7d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 26 May 2024 16:11:13 -1000 Subject: [PATCH] fix task storm to load platforms --- homeassistant/components/mqtt/discovery.py | 34 +++++++++++++++++----- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 1609163d9ef..50b5c45f9f7 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -66,6 +66,7 @@ MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component" MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat( "mqtt_discovery_done_{}_{}" ) +MQTT_DISCOVERY_NEW_PLATFORMS = "mqtt_discovery_new_platforms" TOPIC_BASE = "~" @@ -268,16 +269,30 @@ async def async_start( # noqa: C901 mqtt_data = hass.data[DATA_MQTT] platform_setup_lock: dict[str, asyncio.Lock] = {} + async def _async_load_platforms(platforms: set[str]) -> None: + """Load a platform.""" + for platform in platforms: + if platform not in platform_setup_lock: + platform_setup_lock[platform] = asyncio.Lock() + + async with platform_setup_lock[platform]: + if platform not in mqtt_data.platforms_loaded: + await async_forward_entry_setup_and_setup_discovery( + hass, config_entry, {platform} + ) + + mqtt_data.reload_dispatchers.append( + async_dispatcher_connect( + hass, MQTT_DISCOVERY_NEW_PLATFORMS, _async_load_platforms + ) + ) + async def _async_component_setup(discovery_payload: MQTTDiscoveryPayload) -> None: """Perform component set up.""" discovery_hash = discovery_payload.discovery_data[ATTR_DISCOVERY_HASH] component, discovery_id = discovery_hash - platform_setup_lock.setdefault(component, asyncio.Lock()) - async with platform_setup_lock[component]: - if component not in mqtt_data.platforms_loaded: - await async_forward_entry_setup_and_setup_discovery( - hass, config_entry, {component} - ) + if component not in mqtt_data.platforms_loaded: + await _async_load_platforms({component}) # Add component message = f"Found new component: {component} {discovery_id}" async_log_discovery_origin_info(message, discovery_payload) @@ -313,7 +328,7 @@ async def async_start( # noqa: C901 return component, node_id, object_id = match.groups() - + platforms: set[str] = set() discovered_components: list[MqttComponentConfig] = [] if component == CONF_DEVICE: # Process device based discovery message @@ -331,6 +346,7 @@ async def async_start( # noqa: C901 ] for component_id, config in component_configs.items(): component = config.pop(CONF_PLATFORM) + platforms.add(component) component_node_id = ( f"{component_id} {node_id}" if node_id else component_id ) @@ -358,6 +374,7 @@ async def async_start( # noqa: C901 ) else: + platforms.add(component) # Process component based discovery message try: discovery_payload = MQTTDiscoveryPayload( @@ -373,6 +390,9 @@ async def async_start( # noqa: C901 MqttComponentConfig(component, object_id, node_id, discovery_payload) ) + if missing_platforms := platforms.difference(mqtt_data.platforms_loaded): + async_dispatcher_send(hass, MQTT_DISCOVERY_NEW_PLATFORMS, missing_platforms) + for component_config in discovered_components: component = component_config.component node_id = component_config.node_id