fix task storm to load platforms

This commit is contained in:
J. Nick Koston
2024-05-26 16:11:13 -10:00
parent 2ffac98977
commit 8f12a5f251

View File

@@ -66,6 +66,7 @@ MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component"
MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat(
"mqtt_discovery_done_{}_{}"
)
MQTT_DISCOVERY_NEW_PLATFORMS = "mqtt_discovery_new_platforms"
TOPIC_BASE = "~"
@@ -268,16 +269,30 @@ async def async_start( # noqa: C901
mqtt_data = hass.data[DATA_MQTT]
platform_setup_lock: dict[str, asyncio.Lock] = {}
async def _async_load_platforms(platforms: set[str]) -> None:
"""Load a platform."""
for platform in platforms:
if platform not in platform_setup_lock:
platform_setup_lock[platform] = asyncio.Lock()
async with platform_setup_lock[platform]:
if platform not in mqtt_data.platforms_loaded:
await async_forward_entry_setup_and_setup_discovery(
hass, config_entry, {platform}
)
mqtt_data.reload_dispatchers.append(
async_dispatcher_connect(
hass, MQTT_DISCOVERY_NEW_PLATFORMS, _async_load_platforms
)
)
async def _async_component_setup(discovery_payload: MQTTDiscoveryPayload) -> None:
"""Perform component set up."""
discovery_hash = discovery_payload.discovery_data[ATTR_DISCOVERY_HASH]
component, discovery_id = discovery_hash
platform_setup_lock.setdefault(component, asyncio.Lock())
async with platform_setup_lock[component]:
if component not in mqtt_data.platforms_loaded:
await async_forward_entry_setup_and_setup_discovery(
hass, config_entry, {component}
)
if component not in mqtt_data.platforms_loaded:
await _async_load_platforms({component})
# Add component
message = f"Found new component: {component} {discovery_id}"
async_log_discovery_origin_info(message, discovery_payload)
@@ -313,7 +328,7 @@ async def async_start( # noqa: C901
return
component, node_id, object_id = match.groups()
platforms: set[str] = set()
discovered_components: list[MqttComponentConfig] = []
if component == CONF_DEVICE:
# Process device based discovery message
@@ -331,6 +346,7 @@ async def async_start( # noqa: C901
]
for component_id, config in component_configs.items():
component = config.pop(CONF_PLATFORM)
platforms.add(component)
component_node_id = (
f"{component_id} {node_id}" if node_id else component_id
)
@@ -358,6 +374,7 @@ async def async_start( # noqa: C901
)
else:
platforms.add(component)
# Process component based discovery message
try:
discovery_payload = MQTTDiscoveryPayload(
@@ -373,6 +390,9 @@ async def async_start( # noqa: C901
MqttComponentConfig(component, object_id, node_id, discovery_payload)
)
if missing_platforms := platforms.difference(mqtt_data.platforms_loaded):
async_dispatcher_send(hass, MQTT_DISCOVERY_NEW_PLATFORMS, missing_platforms)
for component_config in discovered_components:
component = component_config.component
node_id = component_config.node_id