mirror of
https://github.com/home-assistant/core.git
synced 2025-08-06 06:05:10 +02:00
MySensors: Improve test coverage
This commit is contained in:
@@ -1,11 +1,7 @@
|
|||||||
"""Config flow for MySensors."""
|
"""Config flow for MySensors."""
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from random import randint
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import async_timeout
|
|
||||||
from mysensors import BaseAsyncGateway
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.mysensors import (
|
from homeassistant.components.mysensors import (
|
||||||
@@ -18,7 +14,6 @@ import homeassistant.helpers.config_validation as cv
|
|||||||
|
|
||||||
from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION
|
from . import CONF_RETAIN, CONF_VERSION, DEFAULT_VERSION
|
||||||
from ... import config_entries
|
from ... import config_entries
|
||||||
from ...helpers.typing import HomeAssistantType
|
|
||||||
from ..mqtt import valid_publish_topic, valid_subscribe_topic
|
from ..mqtt import valid_publish_topic, valid_subscribe_topic
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_BAUD_RATE,
|
CONF_BAUD_RATE,
|
||||||
@@ -34,55 +29,11 @@ from .const import (
|
|||||||
CONF_TOPIC_OUT_PREFIX,
|
CONF_TOPIC_OUT_PREFIX,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
from .gateway import MQTT_COMPONENT, _get_gateway, is_serial_port, is_socket_address
|
from .gateway import MQTT_COMPONENT, is_serial_port, is_socket_address, try_connect
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bool:
|
|
||||||
"""Try to connect to a gateway and report if it worked."""
|
|
||||||
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
|
|
||||||
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
|
|
||||||
user_input_copy = user_input.copy()
|
|
||||||
try:
|
|
||||||
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
|
|
||||||
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
|
|
||||||
)
|
|
||||||
if gateway is None:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
gateway_ready = asyncio.Future()
|
|
||||||
|
|
||||||
def gateway_ready_callback(msg):
|
|
||||||
msg_type = msg.gateway.const.MessageType(msg.type)
|
|
||||||
_LOGGER.debug("Received MySensors msg type %s: %s", msg_type.name, msg)
|
|
||||||
if msg_type.name != "internal":
|
|
||||||
return
|
|
||||||
internal = msg.gateway.const.Internal(msg.sub_type)
|
|
||||||
if internal.name != "I_GATEWAY_READY":
|
|
||||||
return
|
|
||||||
_LOGGER.debug("Received gateway ready")
|
|
||||||
gateway_ready.set_result(True)
|
|
||||||
|
|
||||||
gateway.event_callback = gateway_ready_callback
|
|
||||||
connect_task = None
|
|
||||||
try:
|
|
||||||
connect_task = asyncio.create_task(gateway.start())
|
|
||||||
with async_timeout.timeout(5):
|
|
||||||
await gateway_ready
|
|
||||||
return True
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
_LOGGER.info("Try gateway connect failed with timeout")
|
|
||||||
return False
|
|
||||||
finally:
|
|
||||||
if connect_task is not None and not connect_task.done():
|
|
||||||
connect_task.cancel()
|
|
||||||
asyncio.create_task(gateway.stop())
|
|
||||||
except OSError as err:
|
|
||||||
_LOGGER.info("Try gateway connect failed with exception", exc_info=err)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _get_schema_common() -> dict:
|
def _get_schema_common() -> dict:
|
||||||
"""Create a schema with options common to all gateway types."""
|
"""Create a schema with options common to all gateway types."""
|
||||||
schema = {
|
schema = {
|
||||||
@@ -96,11 +47,9 @@ def _get_schema_common() -> dict:
|
|||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|
||||||
def _validate_version(version: Optional[str]) -> Dict[str, str]:
|
def _validate_version(version: str) -> Dict[str, str]:
|
||||||
"""Validate a version string from the user."""
|
"""Validate a version string from the user."""
|
||||||
errors = {CONF_VERSION: "invalid_version"}
|
errors = {CONF_VERSION: "invalid_version"}
|
||||||
if version is None:
|
|
||||||
return errors
|
|
||||||
version_parts = version.split(".")
|
version_parts = version.split(".")
|
||||||
if len(version_parts) != 2:
|
if len(version_parts) != 2:
|
||||||
return errors
|
return errors
|
||||||
@@ -181,14 +130,9 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
"""Create a config entry for a tcp gateway."""
|
"""Create a config entry for a tcp gateway."""
|
||||||
errors = {}
|
errors = {}
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
try:
|
if CONF_TCP_PORT in user_input:
|
||||||
port = int(user_input.get(CONF_TCP_PORT, ""))
|
port: int = user_input[CONF_TCP_PORT]
|
||||||
except ValueError:
|
if not (0 < port <= 65535):
|
||||||
errors[CONF_TCP_PORT] = "not_a_number"
|
|
||||||
except TypeError: # None, dont complain and use default
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
if port > 65535 or port < 1:
|
|
||||||
errors[CONF_TCP_PORT] = "port_out_of_range"
|
errors[CONF_TCP_PORT] = "port_out_of_range"
|
||||||
|
|
||||||
errors.update(
|
errors.update(
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import logging
|
import logging
|
||||||
|
from random import randint
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Callable, Coroutine, Dict, Optional, Union
|
from typing import Any, Callable, Coroutine, Dict, Optional, Union
|
||||||
@@ -59,6 +60,50 @@ def is_socket_address(value):
|
|||||||
raise vol.Invalid("Device is not a valid domain name or ip address") from err
|
raise vol.Invalid("Device is not a valid domain name or ip address") from err
|
||||||
|
|
||||||
|
|
||||||
|
async def try_connect(hass: HomeAssistantType, user_input: Dict[str, str]) -> bool:
|
||||||
|
"""Try to connect to a gateway and report if it worked."""
|
||||||
|
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
|
||||||
|
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
|
||||||
|
user_input_copy = user_input.copy()
|
||||||
|
try:
|
||||||
|
gateway: Optional[BaseAsyncGateway] = await _get_gateway(
|
||||||
|
hass, user_input_copy, str(randint(0, 10 ** 6)), persistence=False
|
||||||
|
)
|
||||||
|
if gateway is None:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
gateway_ready = asyncio.Future()
|
||||||
|
|
||||||
|
def gateway_ready_callback(msg):
|
||||||
|
msg_type = msg.gateway.const.MessageType(msg.type)
|
||||||
|
_LOGGER.debug("Received MySensors msg type %s: %s", msg_type.name, msg)
|
||||||
|
if msg_type.name != "internal":
|
||||||
|
return
|
||||||
|
internal = msg.gateway.const.Internal(msg.sub_type)
|
||||||
|
if internal.name != "I_GATEWAY_READY":
|
||||||
|
return
|
||||||
|
_LOGGER.debug("Received gateway ready")
|
||||||
|
gateway_ready.set_result(True)
|
||||||
|
|
||||||
|
gateway.event_callback = gateway_ready_callback
|
||||||
|
connect_task = None
|
||||||
|
try:
|
||||||
|
connect_task = asyncio.create_task(gateway.start())
|
||||||
|
with async_timeout.timeout(5):
|
||||||
|
await gateway_ready
|
||||||
|
return True
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
_LOGGER.info("Try gateway connect failed with timeout")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
if connect_task is not None and not connect_task.done():
|
||||||
|
connect_task.cancel()
|
||||||
|
asyncio.create_task(gateway.stop())
|
||||||
|
except OSError as err:
|
||||||
|
_LOGGER.info("Try gateway connect failed with exception", exc_info=err)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_mysensors_gateway(
|
def get_mysensors_gateway(
|
||||||
hass: HomeAssistantType, gateway_id: GatewayId
|
hass: HomeAssistantType, gateway_id: GatewayId
|
||||||
) -> Optional[BaseAsyncGateway]:
|
) -> Optional[BaseAsyncGateway]:
|
||||||
|
Reference in New Issue
Block a user