mirror of
https://github.com/home-assistant/core.git
synced 2025-08-05 13:45:12 +02:00
MySensors: Improve test coverage
This commit is contained in:
@@ -1,11 +1,7 @@
|
||||
"""Config flow for MySensors."""
|
||||
import asyncio
|
||||
import logging
|
||||
from random import randint
|
||||
from typing import Dict, Optional
|
||||
|
||||
import async_timeout
|
||||
from mysensors import BaseAsyncGateway
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.mysensors import (
|
||||
@@ -18,7 +14,6 @@ import homeassistant.helpers.config_validation as cv
|
||||
|
||||
from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION
|
||||
from ... import config_entries
|
||||
from ...helpers.typing import HomeAssistantType
|
||||
from ..mqtt import valid_publish_topic, valid_subscribe_topic
|
||||
from .const import (
|
||||
CONF_BAUD_RATE,
|
||||
@@ -34,55 +29,11 @@ from .const import (
|
||||
CONF_TOPIC_OUT_PREFIX,
|
||||
DOMAIN,
|
||||
)
|
||||
from .gateway import MQTT_COMPONENT, _get_gateway, is_serial_port, is_socket_address
|
||||
from .gateway import MQTT_COMPONENT, is_serial_port, is_socket_address, try_connect
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bool:
|
||||
"""Try to connect to a gateway and report if it worked."""
|
||||
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
|
||||
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
|
||||
user_input_copy = user_input.copy()
|
||||
try:
|
||||
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
|
||||
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
|
||||
)
|
||||
if gateway is None:
|
||||
return False
|
||||
else:
|
||||
gateway_ready = asyncio.Future()
|
||||
|
||||
def gateway_ready_callback(msg):
|
||||
msg_type = msg.gateway.const.MessageType(msg.type)
|
||||
_LOGGER.debug("Received MySensors msg type %s: %s", msg_type.name, msg)
|
||||
if msg_type.name != "internal":
|
||||
return
|
||||
internal = msg.gateway.const.Internal(msg.sub_type)
|
||||
if internal.name != "I_GATEWAY_READY":
|
||||
return
|
||||
_LOGGER.debug("Received gateway ready")
|
||||
gateway_ready.set_result(True)
|
||||
|
||||
gateway.event_callback = gateway_ready_callback
|
||||
connect_task = None
|
||||
try:
|
||||
connect_task = asyncio.create_task(gateway.start())
|
||||
with async_timeout.timeout(5):
|
||||
await gateway_ready
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.info("Try gateway connect failed with timeout")
|
||||
return False
|
||||
finally:
|
||||
if connect_task is not None and not connect_task.done():
|
||||
connect_task.cancel()
|
||||
asyncio.create_task(gateway.stop())
|
||||
except OSError as err:
|
||||
_LOGGER.info("Try gateway connect failed with exception", exc_info=err)
|
||||
return False
|
||||
|
||||
|
||||
def _get_schema_common() -> dict:
|
||||
"""Create a schema with options common to all gateway types."""
|
||||
schema = {
|
||||
@@ -96,11 +47,9 @@ def _get_schema_common() -> dict:
|
||||
return schema
|
||||
|
||||
|
||||
def _validate_version(version: Optional[str]) -> Dict[str, str]:
|
||||
def _validate_version(version: str) -> Dict[str, str]:
|
||||
"""Validate a version string from the user."""
|
||||
errors = {CONF_VERSION: "invalid_version"}
|
||||
if version is None:
|
||||
return errors
|
||||
version_parts = version.split(".")
|
||||
if len(version_parts) != 2:
|
||||
return errors
|
||||
@@ -181,14 +130,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Create a config entry for a tcp gateway."""
|
||||
errors = {}
|
||||
if user_input is not None:
|
||||
try:
|
||||
port = int(user_input.get(CONF_TCP_PORT, ""))
|
||||
except ValueError:
|
||||
errors[CONF_TCP_PORT] = "not_a_number"
|
||||
except TypeError: # None, dont complain and use default
|
||||
pass
|
||||
else:
|
||||
if port > 65535 or port < 1:
|
||||
if CONF_TCP_PORT in user_input:
|
||||
port: int = user_input[CONF_TCP_PORT]
|
||||
if not (0 < port <= 65535):
|
||||
errors[CONF_TCP_PORT] = "port_out_of_range"
|
||||
|
||||
errors.update(
|
||||
|
@@ -2,6 +2,7 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
from random import randint
|
||||
import socket
|
||||
import sys
|
||||
from typing import Any, Callable, Coroutine, Dict, Optional, Union
|
||||
@@ -59,6 +60,50 @@ def is_socket_address(value):
|
||||
raise vol.Invalid("Device is not a valid domain name or ip address") from err
|
||||
|
||||
|
||||
async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bool:
|
||||
"""Try to connect to a gateway and report if it worked."""
|
||||
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
|
||||
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
|
||||
user_input_copy = user_input.copy()
|
||||
try:
|
||||
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
|
||||
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
|
||||
)
|
||||
if gateway is None:
|
||||
return False
|
||||
else:
|
||||
gateway_ready = asyncio.Future()
|
||||
|
||||
def gateway_ready_callback(msg):
|
||||
msg_type = msg.gateway.const.MessageType(msg.type)
|
||||
_LOGGER.debug("Received MySensors msg type %s: %s", msg_type.name, msg)
|
||||
if msg_type.name != "internal":
|
||||
return
|
||||
internal = msg.gateway.const.Internal(msg.sub_type)
|
||||
if internal.name != "I_GATEWAY_READY":
|
||||
return
|
||||
_LOGGER.debug("Received gateway ready")
|
||||
gateway_ready.set_result(True)
|
||||
|
||||
gateway.event_callback = gateway_ready_callback
|
||||
connect_task = None
|
||||
try:
|
||||
connect_task = asyncio.create_task(gateway.start())
|
||||
with async_timeout.timeout(5):
|
||||
await gateway_ready
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.info("Try gateway connect failed with timeout")
|
||||
return False
|
||||
finally:
|
||||
if connect_task is not None and not connect_task.done():
|
||||
connect_task.cancel()
|
||||
asyncio.create_task(gateway.stop())
|
||||
except OSError as err:
|
||||
_LOGGER.info("Try gateway connect failed with exception", exc_info=err)
|
||||
return False
|
||||
|
||||
|
||||
def get_mysensors_gateway(
|
||||
hass: HomeAssistantType, gateway_id: GatewayId
|
||||
) -> Optional[BaseAsyncGateway]:
|
||||
|
Reference in New Issue
Block a user