Make mqtt internal subscription a normal function (#118092)

Co-authored-by: Jan Bouwhuis <jbouwh@users.noreply.github.com>
This commit is contained in:
J. Nick Koston
2024-05-25 11:34:24 -10:00
committed by GitHub
parent ecd48cc447
commit 9be829ba1f
30 changed files with 140 additions and 83 deletions

View File

@ -191,13 +191,25 @@ async def async_subscribe(
Call the return value to unsubscribe.
"""
if not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
translation_key="mqtt_not_setup_cannot_subscribe",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
)
return async_subscribe_internal(hass, topic, msg_callback, qos, encoding)
@callback
def async_subscribe_internal(
hass: HomeAssistant,
topic: str,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
"""Subscribe to an MQTT topic.
This function is internal to the MQTT integration
and may change at any time. It should not be considered
a stable API.
Call the return value to unsubscribe.
"""
try:
mqtt_data = hass.data[DATA_MQTT]
except KeyError as exc:
@ -208,12 +220,15 @@ async def async_subscribe(
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
) from exc
return await mqtt_data.client.async_subscribe(
topic,
msg_callback,
qos,
encoding,
)
client = mqtt_data.client
if not client.connected and not mqtt_config_entry_enabled(hass):
raise HomeAssistantError(
f"Cannot subscribe to topic '{topic}', MQTT is not enabled",
translation_key="mqtt_not_setup_cannot_subscribe",
translation_domain=DOMAIN,
translation_placeholders={"topic": topic},
)
return client.async_subscribe(topic, msg_callback, qos, encoding)
@bind_hass
@ -845,17 +860,15 @@ class MQTT:
f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe]
)
async def async_subscribe(
@callback
def async_subscribe(
self,
topic: str,
msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None],
qos: int,
encoding: str | None = None,
) -> Callable[[], None]:
"""Set up a subscription to a topic with the provided qos.
This method is a coroutine.
"""
"""Set up a subscription to a topic with the provided qos."""
if not isinstance(topic, str):
raise HomeAssistantError("Topic needs to be a string!")
@ -881,18 +894,18 @@ class MQTT:
if self.connected:
self._async_queue_subscriptions(((topic, qos),))
@callback
def async_remove() -> None:
"""Remove subscription."""
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
if subscription in self._retained_topics:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(topic)
return partial(self._async_remove, subscription)
return async_remove
@callback
def _async_remove(self, subscription: Subscription) -> None:
"""Remove subscription."""
self._async_untrack_subscription(subscription)
self._matching_subscriptions.cache_clear()
if subscription in self._retained_topics:
del self._retained_topics[subscription]
# Only unsubscribe if currently connected
if self.connected:
self._async_unsubscribe(subscription.topic)
@callback
def _async_unsubscribe(self, topic: str) -> None: