MySensors: Improve test coverage

This commit is contained in:
functionpointer
2021-01-27 01:28:31 +01:00
parent 5af4cf97f0
commit 2f217db1f1
2 changed files with 50 additions and 61 deletions

View File

@@ -1,11 +1,7 @@
"""Config flow for MySensors.""" """Config flow for MySensors."""
import asyncio
import logging import logging
from random import randint
from typing import Dict, Optional from typing import Dict, Optional
import async_timeout
from mysensors import BaseAsyncGateway
import voluptuous as vol import voluptuous as vol
from homeassistant.components.mysensors import ( from homeassistant.components.mysensors import (
@@ -18,7 +14,6 @@ import homeassistant.helpers.config_validation as cv
from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION
from ... import config_entries from ... import config_entries
from ...helpers.typing import HomeAssistantType
from ..mqtt import valid_publish_topic, valid_subscribe_topic from ..mqtt import valid_publish_topic, valid_subscribe_topic
from .const import ( from .const import (
CONF_BAUD_RATE, CONF_BAUD_RATE,
@@ -34,55 +29,11 @@ from .const import (
CONF_TOPIC_OUT_PREFIX, CONF_TOPIC_OUT_PREFIX,
DOMAIN, DOMAIN,
) )
from .gateway import MQTT_COMPONENT, _get_gateway, is_serial_port, is_socket_address from .gateway import MQTT_COMPONENT, is_serial_port, is_socket_address, try_connect
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bool:
"""Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
user_input_copy = user_input.copy()
try:
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
)
if gateway is None:
return False
else:
gateway_ready = asyncio.Future()
def gateway_ready_callback(msg):
msg_type = msg.gateway.const.MessageType(msg.type)
_LOGGER.debug("Received MySensors msg type %s: %s", msg_type.name, msg)
if msg_type.name != "internal":
return
internal = msg.gateway.const.Internal(msg.sub_type)
if internal.name != "I_GATEWAY_READY":
return
_LOGGER.debug("Received gateway ready")
gateway_ready.set_result(True)
gateway.event_callback = gateway_ready_callback
connect_task = None
try:
connect_task = asyncio.create_task(gateway.start())
with async_timeout.timeout(5):
await gateway_ready
return True
except asyncio.TimeoutError:
_LOGGER.info("Try gateway connect failed with timeout")
return False
finally:
if connect_task is not None and not connect_task.done():
connect_task.cancel()
asyncio.create_task(gateway.stop())
except OSError as err:
_LOGGER.info("Try gateway connect failed with exception", exc_info=err)
return False
def _get_schema_common() -> dict: def _get_schema_common() -> dict:
"""Create a schema with options common to all gateway types.""" """Create a schema with options common to all gateway types."""
schema = { schema = {
@@ -96,11 +47,9 @@ def _get_schema_common() -> dict:
return schema return schema
def _validate_version(version: Optional[str]) -> Dict[str, str]: def _validate_version(version: str) -> Dict[str, str]:
"""Validate a version string from the user.""" """Validate a version string from the user."""
errors = {CONF_VERSION: "invalid_version"} errors = {CONF_VERSION: "invalid_version"}
if version is None:
return errors
version_parts = version.split(".") version_parts = version.split(".")
if len(version_parts) != 2: if len(version_parts) != 2:
return errors return errors
@@ -181,14 +130,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Create a config entry for a tcp gateway.""" """Create a config entry for a tcp gateway."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
try: if CONF_TCP_PORT in user_input:
port = int(user_input.get(CONF_TCP_PORT, "")) port: int = user_input[CONF_TCP_PORT]
except ValueError: if not (0 < port <= 65535):
errors[CONF_TCP_PORT] = "not_a_number"
except TypeError: # None, dont complain and use default
pass
else:
if port > 65535 or port < 1:
errors[CONF_TCP_PORT] = "port_out_of_range" errors[CONF_TCP_PORT] = "port_out_of_range"
errors.update( errors.update(

View File

@@ -2,6 +2,7 @@
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
import logging import logging
from random import randint
import socket import socket
import sys import sys
from typing import Any, Callable, Coroutine, Dict, Optional, Union from typing import Any, Callable, Coroutine, Dict, Optional, Union
@@ -59,6 +60,50 @@ def is_socket_address(value):
raise vol.Invalid("Device is not a valid domain name or ip address") from err raise vol.Invalid("Device is not a valid domain name or ip address") from err
async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bool:
"""Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
user_input_copy = user_input.copy()
try:
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
)
if gateway is None:
return False
else:
gateway_ready = asyncio.Future()
def gateway_ready_callback(msg):
msg_type = msg.gateway.const.MessageType(msg.type)
_LOGGER.debug("Received MySensors msg type %s: %s", msg_type.name, msg)
if msg_type.name != "internal":
return
internal = msg.gateway.const.Internal(msg.sub_type)
if internal.name != "I_GATEWAY_READY":
return
_LOGGER.debug("Received gateway ready")
gateway_ready.set_result(True)
gateway.event_callback = gateway_ready_callback
connect_task = None
try:
connect_task = asyncio.create_task(gateway.start())
with async_timeout.timeout(5):
await gateway_ready
return True
except asyncio.TimeoutError:
_LOGGER.info("Try gateway connect failed with timeout")
return False
finally:
if connect_task is not None and not connect_task.done():
connect_task.cancel()
asyncio.create_task(gateway.stop())
except OSError as err:
_LOGGER.info("Try gateway connect failed with exception", exc_info=err)
return False
def get_mysensors_gateway( def get_mysensors_gateway(
hass: HomeAssistantType, gateway_id: GatewayId hass: HomeAssistantType, gateway_id: GatewayId
) -> Optional[BaseAsyncGateway]: ) -> Optional[BaseAsyncGateway]: