MySensors: Remove union from _get_gateway and remove id from try_connect

This commit is contained in:
functionpointer
2021-01-27 23:10:34 +01:00
parent d5ecfa08f4
commit 43ca44bae3
4 changed files with 82 additions and 69 deletions

View File

@@ -171,7 +171,7 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry) -> bool
gateway = await setup_gateway(hass, entry) gateway = await setup_gateway(hass, entry)
if not gateway: if not gateway:
_LOGGER.error("gateway setup failed") _LOGGER.error("Gateway setup failed for %s", entry.data)
return False return False
if DOMAIN not in hass.data: if DOMAIN not in hass.data:

View File

@@ -2,10 +2,9 @@
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
import logging import logging
from random import randint
import socket import socket
import sys import sys
from typing import Any, Callable, Coroutine, Dict, Optional, Union from typing import Any, Callable, Coroutine, Dict, Optional
import async_timeout import async_timeout
from mysensors import BaseAsyncGateway, Message, Sensor, mysensors from mysensors import BaseAsyncGateway, Message, Sensor, mysensors
@@ -64,10 +63,19 @@ async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bo
"""Try to connect to a gateway and report if it worked.""" """Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT: if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :( return True # dont validate mqtt. mqtt gateways dont send ready messages :(
user_input_copy = user_input.copy()
try: try:
gateway: Optional[BaseAsyncGateway] = await _get_gateway( gateway: Optional[BaseAsyncGateway] = await _get_gateway(
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False hass,
gateway_id=None,
device=user_input[CONF_DEVICE],
version=user_input[CONF_VERSION],
persistence_file=user_input.get(CONF_PERSISTENCE_FILE),
baud_rate=user_input.get(CONF_BAUD_RATE),
tcp_port=user_input.get(CONF_TCP_PORT),
topic_in_prefix=None,
topic_out_prefix=None,
retain=False,
persistence=False,
) )
if gateway is None: if gateway is None:
return False return False
@@ -116,46 +124,48 @@ def get_mysensors_gateway(
async def setup_gateway( async def setup_gateway(
hass: HomeAssistantType, entry: ConfigEntry hass: HomeAssistantType, entry: ConfigEntry
) -> Optional[BaseAsyncGateway]: ) -> Optional[BaseAsyncGateway]:
"""Set up all gateways.""" """Set up the Gateway for the given ConfigEntry."""
ready_gateway = await _get_gateway(hass, entry) ready_gateway = await _get_gateway(
hass,
gateway_id=entry.entry_id,
device=entry.data[CONF_DEVICE],
version=entry.data[CONF_VERSION],
persistence_file=entry.data.get(CONF_PERSISTENCE_FILE),
baud_rate=entry.data.get(CONF_BAUD_RATE),
tcp_port=entry.data.get(CONF_TCP_PORT),
topic_in_prefix=entry.data.get(CONF_TOPIC_IN_PREFIX),
topic_out_prefix=entry.data.get(CONF_TOPIC_OUT_PREFIX),
retain=entry.data.get(CONF_RETAIN, False),
)
return ready_gateway return ready_gateway
async def _get_gateway( async def _get_gateway(
hass: HomeAssistantType, hass: HomeAssistantType,
entry: Union[ConfigEntry, Dict[str, Any]], gateway_id: Optional[GatewayId],
unique_id: Optional[str] = None, device: str,
version: str,
persistence_file: Optional[str] = None,
baud_rate: Optional[int] = None,
tcp_port: Optional[int] = None,
topic_in_prefix: Optional[str] = None,
topic_out_prefix: Optional[str] = None,
retain: bool = False,
persistence: bool = True, # old persistence option has been deprecated. kwarg is here so we can run try_connect() without persistence persistence: bool = True, # old persistence option has been deprecated. kwarg is here so we can run try_connect() without persistence
) -> Optional[BaseAsyncGateway]: ) -> Optional[BaseAsyncGateway]:
"""Return gateway after setup of the gateway.""" """Return gateway after setup of the gateway."""
if isinstance(entry, ConfigEntry): if persistence_file is None and gateway_id is not None:
data: Dict[str, Any] = entry.data persistence_file = f"mysensors_{gateway_id}.json"
unique_id = entry.entry_id
else:
data: Dict[str, Any] = entry
if unique_id is None:
raise ValueError(
"no unique id! either give configEntry for auto-extraction or explicitly give one"
)
persistence_file = data.get(CONF_PERSISTENCE_FILE, f"mysensors_{unique_id}.json")
# interpret relative paths to be in hass config folder. absolute paths will be left as they are # interpret relative paths to be in hass config folder. absolute paths will be left as they are
persistence_file = hass.config.path(persistence_file) persistence_file = hass.config.path(persistence_file)
version: str = data[CONF_VERSION]
device: str = data[CONF_DEVICE]
baud_rate: Optional[int] = data.get(CONF_BAUD_RATE)
tcp_port: Optional[int] = data.get(CONF_TCP_PORT)
in_prefix: str = data.get(CONF_TOPIC_IN_PREFIX, "")
out_prefix: str = data.get(CONF_TOPIC_OUT_PREFIX, "")
if device == MQTT_COMPONENT: if device == MQTT_COMPONENT:
# what is the purpose of this? # what is the purpose of this?
# if not await async_setup_component(hass, MQTT_COMPONENT, entry): # if not await async_setup_component(hass, MQTT_COMPONENT, entry):
# return None # return None
mqtt = hass.components.mqtt mqtt = hass.components.mqtt
retain = data.get(CONF_RETAIN)
def pub_callback(topic, payload, qos, retain): def pub_callback(topic, payload, qos, retain):
"""Call MQTT publish function.""" """Call MQTT publish function."""
@@ -174,8 +184,8 @@ async def _get_gateway(
gateway = mysensors.AsyncMQTTGateway( gateway = mysensors.AsyncMQTTGateway(
pub_callback, pub_callback,
sub_callback, sub_callback,
in_prefix=in_prefix, in_prefix=topic_in_prefix,
out_prefix=out_prefix, out_prefix=topic_out_prefix,
retain=retain, retain=retain,
loop=hass.loop, loop=hass.loop,
event_callback=None, event_callback=None,
@@ -210,8 +220,10 @@ async def _get_gateway(
) )
except vol.Invalid: except vol.Invalid:
# invalid ip address # invalid ip address
_LOGGER.error("Connect failed: Invalid device %s", device)
return None return None
gateway.event_callback = _gw_callback_factory(hass, entry) if gateway_id is not None:
gateway.event_callback = _gw_callback_factory(hass, gateway_id)
if persistence: if persistence:
await gateway.start_persistence() await gateway.start_persistence()
@@ -244,12 +256,12 @@ async def _discover_persistent_devices(
continue continue
node: Sensor = gateway.sensors[node_id] node: Sensor = gateway.sensors[node_id]
for child in node.children.values(): # child is of type ChildSensor for child in node.children.values(): # child is of type ChildSensor
validated = validate_child(hass_config, gateway, node_id, child) validated = validate_child(hass_config.entry_id, gateway, node_id, child)
for platform, dev_ids in validated.items(): for platform, dev_ids in validated.items():
new_devices[platform].extend(dev_ids) new_devices[platform].extend(dev_ids)
_LOGGER.debug("discovering persistent devices: %s", new_devices) _LOGGER.debug("discovering persistent devices: %s", new_devices)
for platform, dev_ids in new_devices.items(): for platform, dev_ids in new_devices.items():
discover_mysensors_platform(hass, hass_config, platform, dev_ids) discover_mysensors_platform(hass, hass_config.entry_id, platform, dev_ids)
if tasks: if tasks:
await asyncio.wait(tasks) await asyncio.wait(tasks)
@@ -300,7 +312,7 @@ async def _gw_start(
def _gw_callback_factory( def _gw_callback_factory(
hass: HomeAssistantType, hass_config: ConfigEntry hass: HomeAssistantType, gateway_id: GatewayId
) -> Callable[[Message], None]: ) -> Callable[[Message], None]:
"""Return a new callback for the gateway.""" """Return a new callback for the gateway."""
@@ -315,12 +327,12 @@ def _gw_callback_factory(
msg_type = msg.gateway.const.MessageType(msg.type) msg_type = msg.gateway.const.MessageType(msg.type)
msg_handler: Callable[ msg_handler: Callable[
[Any, ConfigEntry, Message], Coroutine[None] [Any, GatewayId, Message], Coroutine[None]
] = HANDLERS.get(msg_type.name) ] = HANDLERS.get(msg_type.name)
if msg_handler is None: if msg_handler is None:
return return
hass.async_create_task(msg_handler(hass, hass_config, msg)) hass.async_create_task(msg_handler(hass, gateway_id, msg))
return mysensors_callback return mysensors_callback

View File

@@ -3,13 +3,18 @@ from typing import Dict, List
from mysensors import Message from mysensors import Message
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.util import decorator from homeassistant.util import decorator
from .const import CHILD_CALLBACK, MYSENSORS_GATEWAY_READY, NODE_CALLBACK, DevId from .const import (
CHILD_CALLBACK,
MYSENSORS_GATEWAY_READY,
NODE_CALLBACK,
DevId,
GatewayId,
)
from .device import get_mysensors_devices from .device import get_mysensors_devices
from .helpers import discover_mysensors_platform, validate_set_msg from .helpers import discover_mysensors_platform, validate_set_msg
@@ -17,62 +22,60 @@ HANDLERS = decorator.Registry()
@HANDLERS.register("set") @HANDLERS.register("set")
async def handle_set(hass, hass_config: ConfigEntry, msg: Message) -> None: async def handle_set(hass, gateway_id: GatewayId, msg: Message) -> None:
"""Handle a mysensors set message.""" """Handle a mysensors set message."""
validated = validate_set_msg(hass_config, msg) validated = validate_set_msg(gateway_id, msg)
_handle_child_update(hass, hass_config, validated) _handle_child_update(hass, gateway_id, validated)
@HANDLERS.register("internal") @HANDLERS.register("internal")
async def handle_internal(hass, hass_config: ConfigEntry, msg: Message) -> None: async def handle_internal(hass, gateway_id: GatewayId, msg: Message) -> None:
"""Handle a mysensors internal message.""" """Handle a mysensors internal message."""
internal = msg.gateway.const.Internal(msg.sub_type) internal = msg.gateway.const.Internal(msg.sub_type)
handler = HANDLERS.get(internal.name) handler = HANDLERS.get(internal.name)
if handler is None: if handler is None:
return return
await handler(hass, hass_config, msg) await handler(hass, gateway_id, msg)
@HANDLERS.register("I_BATTERY_LEVEL") @HANDLERS.register("I_BATTERY_LEVEL")
async def handle_battery_level(hass, hass_config: ConfigEntry, msg: Message) -> None: async def handle_battery_level(hass, gateway_id: GatewayId, msg: Message) -> None:
"""Handle an internal battery level message.""" """Handle an internal battery level message."""
_handle_node_update(hass, hass_config, msg) _handle_node_update(hass, gateway_id, msg)
@HANDLERS.register("I_HEARTBEAT_RESPONSE") @HANDLERS.register("I_HEARTBEAT_RESPONSE")
async def handle_heartbeat(hass, hass_config: ConfigEntry, msg: Message) -> None: async def handle_heartbeat(hass, gateway_id: GatewayId, msg: Message) -> None:
"""Handle an heartbeat.""" """Handle an heartbeat."""
_handle_node_update(hass, hass_config, msg) _handle_node_update(hass, gateway_id, msg)
@HANDLERS.register("I_SKETCH_NAME") @HANDLERS.register("I_SKETCH_NAME")
async def handle_sketch_name(hass, hass_config: ConfigEntry, msg: Message) -> None: async def handle_sketch_name(hass, gatway_id: GatewayId, msg: Message) -> None:
"""Handle an internal sketch name message.""" """Handle an internal sketch name message."""
_handle_node_update(hass, hass_config, msg) _handle_node_update(hass, gatway_id, msg)
@HANDLERS.register("I_SKETCH_VERSION") @HANDLERS.register("I_SKETCH_VERSION")
async def handle_sketch_version(hass, hass_config: ConfigEntry, msg: Message) -> None: async def handle_sketch_version(hass, gatway_id: GatewayId, msg: Message) -> None:
"""Handle an internal sketch version message.""" """Handle an internal sketch version message."""
_handle_node_update(hass, hass_config, msg) _handle_node_update(hass, gatway_id, msg)
@HANDLERS.register("I_GATEWAY_READY") @HANDLERS.register("I_GATEWAY_READY")
async def handle_gateway_ready(hass, hass_config: ConfigEntry, msg: Message) -> None: async def handle_gateway_ready(hass, gatway_id: GatewayId, msg: Message) -> None:
"""Handle an internal gateway ready message. """Handle an internal gateway ready message.
Set asyncio future result if gateway is ready. Set asyncio future result if gateway is ready.
""" """
gateway_ready = hass.data.get(MYSENSORS_GATEWAY_READY.format(hass_config.entry_id)) gateway_ready = hass.data.get(MYSENSORS_GATEWAY_READY.format(gatway_id))
if gateway_ready is None or gateway_ready.cancelled(): if gateway_ready is None or gateway_ready.cancelled():
return return
gateway_ready.set_result(True) gateway_ready.set_result(True)
@callback @callback
def _handle_child_update( def _handle_child_update(hass, gatway_id: GatewayId, validated: Dict[str, List[DevId]]):
hass, hass_config: ConfigEntry, validated: Dict[str, List[DevId]]
):
"""Handle a child update.""" """Handle a child update."""
signals: List[str] = [] signals: List[str] = []
@@ -87,7 +90,7 @@ def _handle_child_update(
else: else:
new_dev_ids.append(dev_id) new_dev_ids.append(dev_id)
if new_dev_ids: if new_dev_ids:
discover_mysensors_platform(hass, hass_config, platform, new_dev_ids) discover_mysensors_platform(hass, gatway_id, platform, new_dev_ids)
for signal in set(signals): for signal in set(signals):
# Only one signal per device is needed. # Only one signal per device is needed.
# A device can have multiple platforms, ie multiple schemas. # A device can have multiple platforms, ie multiple schemas.
@@ -95,9 +98,7 @@ def _handle_child_update(
@callback @callback
def _handle_node_update( def _handle_node_update(hass: HomeAssistantType, gateway_id: GatewayId, msg: Message):
hass: HomeAssistantType, hass_config: ConfigEntry, msg: Message
):
"""Handle a node update.""" """Handle a node update."""
signal = NODE_CALLBACK.format(hass_config.entry_id, msg.node_id) signal = NODE_CALLBACK.format(gateway_id, msg.node_id)
async_dispatcher_send(hass, signal) async_dispatcher_send(hass, signal)

View File

@@ -11,10 +11,9 @@ import voluptuous as vol
from homeassistant.const import CONF_NAME from homeassistant.const import CONF_NAME
from homeassistant.core import callback from homeassistant.core import callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from ...config_entries import ConfigEntry
from ...helpers.dispatcher import async_dispatcher_send
from .const import ( from .const import (
ATTR_DEVICES, ATTR_DEVICES,
ATTR_GATEWAY_ID, ATTR_GATEWAY_ID,
@@ -23,6 +22,7 @@ from .const import (
MYSENSORS_DISCOVERY, MYSENSORS_DISCOVERY,
TYPE_TO_PLATFORMS, TYPE_TO_PLATFORMS,
DevId, DevId,
GatewayId,
SensorType, SensorType,
ValueType, ValueType,
) )
@@ -33,17 +33,17 @@ SCHEMAS = Registry()
@callback @callback
def discover_mysensors_platform( def discover_mysensors_platform(
hass, hass_config: ConfigEntry, platform: str, new_devices: List[DevId] hass, gateway_id: GatewayId, platform: str, new_devices: List[DevId]
) -> None: ) -> None:
"""Discover a MySensors platform.""" """Discover a MySensors platform."""
_LOGGER.debug("discovering platform %s with devIds: %s", platform, new_devices) _LOGGER.debug("discovering platform %s with devIds: %s", platform, new_devices)
async_dispatcher_send( async_dispatcher_send(
hass, hass,
MYSENSORS_DISCOVERY.format(hass_config.entry_id, platform), MYSENSORS_DISCOVERY.format(gateway_id, platform),
{ {
ATTR_DEVICES: new_devices, ATTR_DEVICES: new_devices,
CONF_NAME: DOMAIN, CONF_NAME: DOMAIN,
ATTR_GATEWAY_ID: hass_config.entry_id, ATTR_GATEWAY_ID: gateway_id,
}, },
) )
@@ -130,12 +130,12 @@ def invalid_msg(
) )
def validate_set_msg(hass_config: ConfigEntry, msg: Message) -> Dict[str, List[DevId]]: def validate_set_msg(gateway_id: GatewayId, msg: Message) -> Dict[str, List[DevId]]:
"""Validate a set message.""" """Validate a set message."""
if not validate_node(msg.gateway, msg.node_id): if not validate_node(msg.gateway, msg.node_id):
return {} return {}
child = msg.gateway.sensors[msg.node_id].children[msg.child_id] child = msg.gateway.sensors[msg.node_id].children[msg.child_id]
return validate_child(hass_config, msg.gateway, msg.node_id, child, msg.sub_type) return validate_child(gateway_id, msg.gateway, msg.node_id, child, msg.sub_type)
def validate_node(gateway: BaseAsyncGateway, node_id: int) -> bool: def validate_node(gateway: BaseAsyncGateway, node_id: int) -> bool:
@@ -147,7 +147,7 @@ def validate_node(gateway: BaseAsyncGateway, node_id: int) -> bool:
def validate_child( def validate_child(
hass_config: ConfigEntry, gateway_id: GatewayId,
gateway: BaseAsyncGateway, gateway: BaseAsyncGateway,
node_id: int, node_id: int,
child: ChildSensor, child: ChildSensor,
@@ -195,7 +195,7 @@ def validate_child(
) )
continue continue
dev_id: DevId = ( dev_id: DevId = (
hass_config.entry_id, gateway_id,
node_id, node_id,
child.id, child.id,
set_req[v_name].value, set_req[v_name].value,