MySensors: Remove union from _get_gateway and remove id from try_connect

This commit is contained in:
functionpointer
2021-01-27 23:10:34 +01:00
parent d5ecfa08f4
commit 43ca44bae3
4 changed files with 82 additions and 69 deletions

View File

@@ -171,7 +171,7 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry) -> bool
gateway = await setup_gateway(hass, entry)
if not gateway:
_LOGGER.error("gateway setup failed")
_LOGGER.error("Gateway setup failed for %s", entry.data)
return False
if DOMAIN not in hass.data:

View File

@@ -2,10 +2,9 @@
import asyncio
from collections import defaultdict
import logging
from random import randint
import socket
import sys
from typing import Any, Callable, Coroutine, Dict, Optional, Union
from typing import Any, Callable, Coroutine, Dict, Optional
import async_timeout
from mysensors import BaseAsyncGateway, Message, Sensor, mysensors
@@ -64,10 +63,19 @@ async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bo
"""Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
user_input_copy = user_input.copy()
try:
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
hass,
gateway_id=None,
device=user_input[CONF_DEVICE],
version=user_input[CONF_VERSION],
persistence_file=user_input.get(CONF_PERSISTENCE_FILE),
baud_rate=user_input.get(CONF_BAUD_RATE),
tcp_port=user_input.get(CONF_TCP_PORT),
topic_in_prefix=None,
topic_out_prefix=None,
retain=False,
persistence=False,
)
if gateway is None:
return False
@@ -116,46 +124,48 @@ def get_mysensors_gateway(
async def setup_gateway(
hass: HomeAssistantType, entry: ConfigEntry
) -> Optional[BaseAsyncGateway]:
"""Set up all gateways."""
"""Set up the Gateway for the given ConfigEntry."""
ready_gateway = await _get_gateway(hass, entry)
ready_gateway = await _get_gateway(
hass,
gateway_id=entry.entry_id,
device=entry.data[CONF_DEVICE],
version=entry.data[CONF_VERSION],
persistence_file=entry.data.get(CONF_PERSISTENCE_FILE),
baud_rate=entry.data.get(CONF_BAUD_RATE),
tcp_port=entry.data.get(CONF_TCP_PORT),
topic_in_prefix=entry.data.get(CONF_TOPIC_IN_PREFIX),
topic_out_prefix=entry.data.get(CONF_TOPIC_OUT_PREFIX),
retain=entry.data.get(CONF_RETAIN, False),
)
return ready_gateway
async def _get_gateway(
hass: HomeAssistantType,
entry: Union[ConfigEntry, Dict[str, Any]],
unique_id: Optional[str] = None,
gateway_id: Optional[GatewayId],
device: str,
version: str,
persistence_file: Optional[str] = None,
baud_rate: Optional[int] = None,
tcp_port: Optional[int] = None,
topic_in_prefix: Optional[str] = None,
topic_out_prefix: Optional[str] = None,
retain: bool = False,
persistence: bool = True, # old persistence option has been deprecated. kwarg is here so we can run try_connect() without persistence
) -> Optional[BaseAsyncGateway]:
"""Return gateway after setup of the gateway."""
if isinstance(entry, ConfigEntry):
data: Dict[str, Any] = entry.data
unique_id = entry.entry_id
else:
data: Dict[str, Any] = entry
if unique_id is None:
raise ValueError(
"no unique id! either give configEntry for auto-extraction or explicitly give one"
)
persistence_file = data.get(CONF_PERSISTENCE_FILE, f"mysensors_{unique_id}.json")
if persistence_file is None and gateway_id is not None:
persistence_file = f"mysensors_{gateway_id}.json"
# interpret relative paths to be in hass config folder. absolute paths will be left as they are
persistence_file = hass.config.path(persistence_file)
version: str = data[CONF_VERSION]
device: str = data[CONF_DEVICE]
baud_rate: Optional[int] = data.get(CONF_BAUD_RATE)
tcp_port: Optional[int] = data.get(CONF_TCP_PORT)
in_prefix: str = data.get(CONF_TOPIC_IN_PREFIX, "")
out_prefix: str = data.get(CONF_TOPIC_OUT_PREFIX, "")
if device == MQTT_COMPONENT:
# what is the purpose of this?
# if not await async_setup_component(hass, MQTT_COMPONENT, entry):
# return None
mqtt = hass.components.mqtt
retain = data.get(CONF_RETAIN)
def pub_callback(topic, payload, qos, retain):
"""Call MQTT publish function."""
@@ -174,8 +184,8 @@ async def _get_gateway(
gateway = mysensors.AsyncMQTTGateway(
pub_callback,
sub_callback,
in_prefix=in_prefix,
out_prefix=out_prefix,
in_prefix=topic_in_prefix,
out_prefix=topic_out_prefix,
retain=retain,
loop=hass.loop,
event_callback=None,
@@ -210,8 +220,10 @@ async def _get_gateway(
)
except vol.Invalid:
# invalid ip address
_LOGGER.error("Connect failed: Invalid device %s", device)
return None
gateway.event_callback = _gw_callback_factory(hass, entry)
if gateway_id is not None:
gateway.event_callback = _gw_callback_factory(hass, gateway_id)
if persistence:
await gateway.start_persistence()
@@ -244,12 +256,12 @@ async def _discover_persistent_devices(
continue
node: Sensor = gateway.sensors[node_id]
for child in node.children.values(): # child is of type ChildSensor
validated = validate_child(hass_config, gateway, node_id, child)
validated = validate_child(hass_config.entry_id, gateway, node_id, child)
for platform, dev_ids in validated.items():
new_devices[platform].extend(dev_ids)
_LOGGER.debug("discovering persistent devices: %s", new_devices)
for platform, dev_ids in new_devices.items():
discover_mysensors_platform(hass, hass_config, platform, dev_ids)
discover_mysensors_platform(hass, hass_config.entry_id, platform, dev_ids)
if tasks:
await asyncio.wait(tasks)
@@ -300,7 +312,7 @@ async def _gw_start(
def _gw_callback_factory(
hass: HomeAssistantType, hass_config: ConfigEntry
hass: HomeAssistantType, gateway_id: GatewayId
) -> Callable[[Message], None]:
"""Return a new callback for the gateway."""
@@ -315,12 +327,12 @@ def _gw_callback_factory(
msg_type = msg.gateway.const.MessageType(msg.type)
msg_handler: Callable[
[Any, ConfigEntry, Message], Coroutine[None]
[Any, GatewayId, Message], Coroutine[None]
] = HANDLERS.get(msg_type.name)
if msg_handler is None:
return
hass.async_create_task(msg_handler(hass, hass_config, msg))
hass.async_create_task(msg_handler(hass, gateway_id, msg))
return mysensors_callback

View File

@@ -3,13 +3,18 @@ from typing import Dict, List
from mysensors import Message
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.util import decorator
from .const import CHILD_CALLBACK, MYSENSORS_GATEWAY_READY, NODE_CALLBACK, DevId
from .const import (
CHILD_CALLBACK,
MYSENSORS_GATEWAY_READY,
NODE_CALLBACK,
DevId,
GatewayId,
)
from .device import get_mysensors_devices
from .helpers import discover_mysensors_platform, validate_set_msg
@@ -17,62 +22,60 @@ HANDLERS = decorator.Registry()
@HANDLERS.register("set")
async def handle_set(hass, hass_config: ConfigEntry, msg: Message) -> None:
async def handle_set(hass, gateway_id: GatewayId, msg: Message) -> None:
"""Handle a mysensors set message."""
validated = validate_set_msg(hass_config, msg)
_handle_child_update(hass, hass_config, validated)
validated = validate_set_msg(gateway_id, msg)
_handle_child_update(hass, gateway_id, validated)
@HANDLERS.register("internal")
async def handle_internal(hass, hass_config: ConfigEntry, msg: Message) -> None:
async def handle_internal(hass, gateway_id: GatewayId, msg: Message) -> None:
"""Handle a mysensors internal message."""
internal = msg.gateway.const.Internal(msg.sub_type)
handler = HANDLERS.get(internal.name)
if handler is None:
return
await handler(hass, hass_config, msg)
await handler(hass, gateway_id, msg)
@HANDLERS.register("I_BATTERY_LEVEL")
async def handle_battery_level(hass, hass_config: ConfigEntry, msg: Message) -> None:
async def handle_battery_level(hass, gateway_id: GatewayId, msg: Message) -> None:
"""Handle an internal battery level message."""
_handle_node_update(hass, hass_config, msg)
_handle_node_update(hass, gateway_id, msg)
@HANDLERS.register("I_HEARTBEAT_RESPONSE")
async def handle_heartbeat(hass, hass_config: ConfigEntry, msg: Message) -> None:
async def handle_heartbeat(hass, gateway_id: GatewayId, msg: Message) -> None:
"""Handle an heartbeat."""
_handle_node_update(hass, hass_config, msg)
_handle_node_update(hass, gateway_id, msg)
@HANDLERS.register("I_SKETCH_NAME")
async def handle_sketch_name(hass, hass_config: ConfigEntry, msg: Message) -> None:
async def handle_sketch_name(hass, gatway_id: GatewayId, msg: Message) -> None:
"""Handle an internal sketch name message."""
_handle_node_update(hass, hass_config, msg)
_handle_node_update(hass, gatway_id, msg)
@HANDLERS.register("I_SKETCH_VERSION")
async def handle_sketch_version(hass, hass_config: ConfigEntry, msg: Message) -> None:
async def handle_sketch_version(hass, gatway_id: GatewayId, msg: Message) -> None:
"""Handle an internal sketch version message."""
_handle_node_update(hass, hass_config, msg)
_handle_node_update(hass, gatway_id, msg)
@HANDLERS.register("I_GATEWAY_READY")
async def handle_gateway_ready(hass, hass_config: ConfigEntry, msg: Message) -> None:
async def handle_gateway_ready(hass, gatway_id: GatewayId, msg: Message) -> None:
"""Handle an internal gateway ready message.
Set asyncio future result if gateway is ready.
"""
gateway_ready = hass.data.get(MYSENSORS_GATEWAY_READY.format(hass_config.entry_id))
gateway_ready = hass.data.get(MYSENSORS_GATEWAY_READY.format(gatway_id))
if gateway_ready is None or gateway_ready.cancelled():
return
gateway_ready.set_result(True)
@callback
def _handle_child_update(
hass, hass_config: ConfigEntry, validated: Dict[str, List[DevId]]
):
def _handle_child_update(hass, gatway_id: GatewayId, validated: Dict[str, List[DevId]]):
"""Handle a child update."""
signals: List[str] = []
@@ -87,7 +90,7 @@ def _handle_child_update(
else:
new_dev_ids.append(dev_id)
if new_dev_ids:
discover_mysensors_platform(hass, hass_config, platform, new_dev_ids)
discover_mysensors_platform(hass, gatway_id, platform, new_dev_ids)
for signal in set(signals):
# Only one signal per device is needed.
# A device can have multiple platforms, ie multiple schemas.
@@ -95,9 +98,7 @@ def _handle_child_update(
@callback
def _handle_node_update(
hass: HomeAssistantType, hass_config: ConfigEntry, msg: Message
):
def _handle_node_update(hass: HomeAssistantType, gateway_id: GatewayId, msg: Message):
"""Handle a node update."""
signal = NODE_CALLBACK.format(hass_config.entry_id, msg.node_id)
signal = NODE_CALLBACK.format(gateway_id, msg.node_id)
async_dispatcher_send(hass, signal)

View File

@@ -11,10 +11,9 @@ import voluptuous as vol
from homeassistant.const import CONF_NAME
from homeassistant.core import callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.util.decorator import Registry
from ...config_entries import ConfigEntry
from ...helpers.dispatcher import async_dispatcher_send
from .const import (
ATTR_DEVICES,
ATTR_GATEWAY_ID,
@@ -23,6 +22,7 @@ from .const import (
MYSENSORS_DISCOVERY,
TYPE_TO_PLATFORMS,
DevId,
GatewayId,
SensorType,
ValueType,
)
@@ -33,17 +33,17 @@ SCHEMAS = Registry()
@callback
def discover_mysensors_platform(
hass, hass_config: ConfigEntry, platform: str, new_devices: List[DevId]
hass, gateway_id: GatewayId, platform: str, new_devices: List[DevId]
) -> None:
"""Discover a MySensors platform."""
_LOGGER.debug("discovering platform %s with devIds: %s", platform, new_devices)
async_dispatcher_send(
hass,
MYSENSORS_DISCOVERY.format(hass_config.entry_id, platform),
MYSENSORS_DISCOVERY.format(gateway_id, platform),
{
ATTR_DEVICES: new_devices,
CONF_NAME: DOMAIN,
ATTR_GATEWAY_ID: hass_config.entry_id,
ATTR_GATEWAY_ID: gateway_id,
},
)
@@ -130,12 +130,12 @@ def invalid_msg(
)
def validate_set_msg(hass_config: ConfigEntry, msg: Message) -> Dict[str, List[DevId]]:
def validate_set_msg(gateway_id: GatewayId, msg: Message) -> Dict[str, List[DevId]]:
"""Validate a set message."""
if not validate_node(msg.gateway, msg.node_id):
return {}
child = msg.gateway.sensors[msg.node_id].children[msg.child_id]
return validate_child(hass_config, msg.gateway, msg.node_id, child, msg.sub_type)
return validate_child(gateway_id, msg.gateway, msg.node_id, child, msg.sub_type)
def validate_node(gateway: BaseAsyncGateway, node_id: int) -> bool:
@@ -147,7 +147,7 @@ def validate_node(gateway: BaseAsyncGateway, node_id: int) -> bool:
def validate_child(
hass_config: ConfigEntry,
gateway_id: GatewayId,
gateway: BaseAsyncGateway,
node_id: int,
child: ChildSensor,
@@ -195,7 +195,7 @@ def validate_child(
)
continue
dev_id: DevId = (
hass_config.entry_id,
gateway_id,
node_id,
child.id,
set_req[v_name].value,