MySensors: Improve test coverage

This commit is contained in:
functionpointer
2021-01-27 01:28:31 +01:00
parent 5af4cf97f0
commit 2f217db1f1
2 changed files with 50 additions and 61 deletions

View File

@@ -1,11 +1,7 @@
"""Config flow for MySensors."""
import asyncio
import logging
from random import randint
from typing import Dict, Optional
import async_timeout
from mysensors import BaseAsyncGateway
import voluptuous as vol
from homeassistant.components.mysensors import (
@@ -18,7 +14,6 @@ import homeassistant.helpers.config_validation as cv
from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION
from ... import config_entries
from ...helpers.typing import HomeAssistantType
from ..mqtt import valid_publish_topic, valid_subscribe_topic
from .const import (
CONF_BAUD_RATE,
@@ -34,55 +29,11 @@ from .const import (
CONF_TOPIC_OUT_PREFIX,
DOMAIN,
)
from .gateway import MQTT_COMPONENT, _get_gateway, is_serial_port, is_socket_address
from .gateway import MQTT_COMPONENT, is_serial_port, is_socket_address, try_connect
_LOGGER = logging.getLogger(__name__)
async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bool:
"""Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
user_input_copy = user_input.copy()
try:
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
)
if gateway is None:
return False
else:
gateway_ready = asyncio.Future()
def gateway_ready_callback(msg):
msg_type = msg.gateway.const.MessageType(msg.type)
_LOGGER.debug("Received MySensors msg type %s: %s", msg_type.name, msg)
if msg_type.name != "internal":
return
internal = msg.gateway.const.Internal(msg.sub_type)
if internal.name != "I_GATEWAY_READY":
return
_LOGGER.debug("Received gateway ready")
gateway_ready.set_result(True)
gateway.event_callback = gateway_ready_callback
connect_task = None
try:
connect_task = asyncio.create_task(gateway.start())
with async_timeout.timeout(5):
await gateway_ready
return True
except asyncio.TimeoutError:
_LOGGER.info("Try gateway connect failed with timeout")
return False
finally:
if connect_task is not None and not connect_task.done():
connect_task.cancel()
asyncio.create_task(gateway.stop())
except OSError as err:
_LOGGER.info("Try gateway connect failed with exception", exc_info=err)
return False
def _get_schema_common() -> dict:
"""Create a schema with options common to all gateway types."""
schema = {
@@ -96,11 +47,9 @@ def _get_schema_common() -> dict:
return schema
def _validate_version(version: Optional[str]) -> Dict[str, str]:
def _validate_version(version: str) -> Dict[str, str]:
"""Validate a version string from the user."""
errors = {CONF_VERSION: "invalid_version"}
if version is None:
return errors
version_parts = version.split(".")
if len(version_parts) != 2:
return errors
@@ -181,14 +130,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Create a config entry for a tcp gateway."""
errors = {}
if user_input is not None:
try:
port = int(user_input.get(CONF_TCP_PORT, ""))
except ValueError:
errors[CONF_TCP_PORT] = "not_a_number"
except TypeError: # None, dont complain and use default
pass
else:
if port > 65535 or port < 1:
if CONF_TCP_PORT in user_input:
port: int = user_input[CONF_TCP_PORT]
if not (0 < port <= 65535):
errors[CONF_TCP_PORT] = "port_out_of_range"
errors.update(

View File

@@ -2,6 +2,7 @@
import asyncio
from collections import defaultdict
import logging
from random import randint
import socket
import sys
from typing import Any, Callable, Coroutine, Dict, Optional, Union
@@ -59,6 +60,50 @@ def is_socket_address(value):
raise vol.Invalid("Device is not a valid domain name or ip address") from err
async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bool:
"""Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
user_input_copy = user_input.copy()
try:
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
)
if gateway is None:
return False
else:
gateway_ready = asyncio.Future()
def gateway_ready_callback(msg):
msg_type = msg.gateway.const.MessageType(msg.type)
_LOGGER.debug("Received MySensors msg type %s: %s", msg_type.name, msg)
if msg_type.name != "internal":
return
internal = msg.gateway.const.Internal(msg.sub_type)
if internal.name != "I_GATEWAY_READY":
return
_LOGGER.debug("Received gateway ready")
gateway_ready.set_result(True)
gateway.event_callback = gateway_ready_callback
connect_task = None
try:
connect_task = asyncio.create_task(gateway.start())
with async_timeout.timeout(5):
await gateway_ready
return True
except asyncio.TimeoutError:
_LOGGER.info("Try gateway connect failed with timeout")
return False
finally:
if connect_task is not None and not connect_task.done():
connect_task.cancel()
asyncio.create_task(gateway.stop())
except OSError as err:
_LOGGER.info("Try gateway connect failed with exception", exc_info=err)
return False
def get_mysensors_gateway(
hass: HomeAssistantType, gateway_id: GatewayId
) -> Optional[BaseAsyncGateway]: