speed up _replace_all_abbreviations

This commit is contained in:
J. Nick Koston
2024-05-21 16:11:09 -10:00
committed by jbouwh
parent d91de41dc1
commit f9c4a736fb

View File

@@ -42,6 +42,11 @@ from .models import DATA_MQTT, MqttOriginInfo, ReceiveMessage
from .schemas import MQTT_ORIGIN_INFO_SCHEMA
from .util import async_forward_entry_setup_and_setup_discovery
ABBREVIATIONS_SET = set(ABBREVIATIONS)
DEVICE_ABBREVIATIONS_SET = set(DEVICE_ABBREVIATIONS)
ORIGIN_ABBREVIATIONS_SET = set(ORIGIN_ABBREVIATIONS)
_LOGGER = logging.getLogger(__name__)
TOPIC_MATCHER = re.compile(
@@ -109,34 +114,42 @@ def async_log_discovery_origin_info(
)
@callback
def _replace_abbreviations(
payload: Any | dict[str, Any],
abbreviations: dict[str, str],
abbreviations_set: set[str],
) -> None:
"""Replace abbreviations in an MQTT discovery payload."""
if not isinstance(payload, dict):
return
for key in abbreviations_set.intersection(payload):
payload[abbreviations[key]] = payload.pop(key)
@callback
def _replace_all_abbreviations(discovery_payload: Any | dict[str, Any]) -> None:
"""Replace all abbreviations in an MQTT discovery payload."""
@callback
def _replace_abbreviations(
discovery_payload: Any | dict[str, Any], abbreviations: dict[str, str]
) -> None:
"""Replace abbreviations in an MQTT discovery payload."""
if not isinstance(discovery_payload, dict):
return
for key in list(discovery_payload):
abbreviated_key = key
key = abbreviations.get(key, key)
discovery_payload[key] = discovery_payload.pop(abbreviated_key)
_replace_abbreviations(discovery_payload, ABBREVIATIONS)
_replace_abbreviations(discovery_payload, ABBREVIATIONS, ABBREVIATIONS_SET)
if CONF_ORIGIN in discovery_payload:
origin_info: dict[str, Any] = discovery_payload[CONF_ORIGIN]
_replace_abbreviations(origin_info, ORIGIN_ABBREVIATIONS)
_replace_abbreviations(
discovery_payload[CONF_ORIGIN],
ORIGIN_ABBREVIATIONS,
ORIGIN_ABBREVIATIONS_SET,
)
if CONF_DEVICE in discovery_payload:
_replace_abbreviations(discovery_payload[CONF_DEVICE], DEVICE_ABBREVIATIONS)
_replace_abbreviations(
discovery_payload[CONF_DEVICE],
DEVICE_ABBREVIATIONS,
DEVICE_ABBREVIATIONS_SET,
)
if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(discovery_payload[CONF_AVAILABILITY]):
_replace_abbreviations(availability_conf, ABBREVIATIONS)
_replace_abbreviations(availability_conf, ABBREVIATIONS, ABBREVIATIONS_SET)
@callback