Avoid generating matchers that will never be used in MQTT (#118068)

This commit is contained in:
J. Nick Koston
2024-05-24 14:04:03 -10:00
committed by GitHub
parent fa1ef8b0cf
commit 65a702761b

View File

@ -238,12 +238,13 @@ def subscribe(
return remove return remove
@dataclass(frozen=True) @dataclass(slots=True, frozen=True)
class Subscription: class Subscription:
"""Class to hold data about an active subscription.""" """Class to hold data about an active subscription."""
topic: str topic: str
matcher: Any is_simple_match: bool
complex_matcher: Callable[[str], bool] | None
job: HassJob[[ReceiveMessage], Coroutine[Any, Any, None] | None] job: HassJob[[ReceiveMessage], Coroutine[Any, Any, None] | None]
qos: int = 0 qos: int = 0
encoding: str | None = "utf-8" encoding: str | None = "utf-8"
@ -312,11 +313,6 @@ class MqttClientSetup:
return self._client return self._client
def _is_simple_match(topic: str) -> bool:
"""Return if a topic is a simple match."""
return not ("+" in topic or "#" in topic)
class EnsureJobAfterCooldown: class EnsureJobAfterCooldown:
"""Ensure a cool down period before executing a job. """Ensure a cool down period before executing a job.
@ -788,7 +784,7 @@ class MQTT:
The caller is responsible clearing the cache of _matching_subscriptions. The caller is responsible clearing the cache of _matching_subscriptions.
""" """
if _is_simple_match(subscription.topic): if subscription.is_simple_match:
self._simple_subscriptions.setdefault(subscription.topic, []).append( self._simple_subscriptions.setdefault(subscription.topic, []).append(
subscription subscription
) )
@ -805,7 +801,7 @@ class MQTT:
""" """
topic = subscription.topic topic = subscription.topic
try: try:
if _is_simple_match(topic): if subscription.is_simple_match:
simple_subscriptions = self._simple_subscriptions simple_subscriptions = self._simple_subscriptions
simple_subscriptions[topic].remove(subscription) simple_subscriptions[topic].remove(subscription)
if not simple_subscriptions[topic]: if not simple_subscriptions[topic]:
@ -846,8 +842,11 @@ class MQTT:
if not isinstance(topic, str): if not isinstance(topic, str):
raise HomeAssistantError("Topic needs to be a string!") raise HomeAssistantError("Topic needs to be a string!")
is_simple_match = not ("+" in topic or "#" in topic)
matcher = None if is_simple_match else _matcher_for_topic(topic)
subscription = Subscription( subscription = Subscription(
topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding topic, is_simple_match, matcher, HassJob(msg_callback), qos, encoding
) )
self._async_track_subscription(subscription) self._async_track_subscription(subscription)
self._matching_subscriptions.cache_clear() self._matching_subscriptions.cache_clear()
@ -1053,7 +1052,9 @@ class MQTT:
subscriptions.extend( subscriptions.extend(
subscription subscription
for subscription in self._wildcard_subscriptions for subscription in self._wildcard_subscriptions
if subscription.matcher(topic) # mypy doesn't know that complex_matcher is always set when
# is_simple_match is False
if subscription.complex_matcher(topic) # type: ignore[misc]
) )
return subscriptions return subscriptions
@ -1241,7 +1242,7 @@ def _raise_on_error(result_code: int) -> None:
raise HomeAssistantError(f"Error talking to MQTT: {message}") raise HomeAssistantError(f"Error talking to MQTT: {message}")
def _matcher_for_topic(subscription: str) -> Any: def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:
# pylint: disable-next=import-outside-toplevel # pylint: disable-next=import-outside-toplevel
from paho.mqtt.matcher import MQTTMatcher from paho.mqtt.matcher import MQTTMatcher