From c4da5374aeacac7458994e58f6b7f2c741c8ab47 Mon Sep 17 00:00:00 2001 From: G Johansson Date: Mon, 7 Aug 2023 17:25:02 +0200 Subject: [PATCH] Refactor Trafikverket Train to improve config flow (#97929) * Refactor tvt * review fixes * review comments --- .coveragerc | 1 + .../trafikverket_train/config_flow.py | 154 +++++++++----- .../trafikverket_train/coordinator.py | 28 +-- .../trafikverket_train/strings.json | 6 +- .../components/trafikverket_train/util.py | 25 ++- .../trafikverket_train/test_config_flow.py | 191 +++++++++++++++--- 6 files changed, 290 insertions(+), 115 deletions(-) diff --git a/.coveragerc b/.coveragerc index d895b1adf0a..01e1d0d3b0e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1341,6 +1341,7 @@ omit = homeassistant/components/trafikverket_train/__init__.py homeassistant/components/trafikverket_train/coordinator.py homeassistant/components/trafikverket_train/sensor.py + homeassistant/components/trafikverket_train/util.py homeassistant/components/trafikverket_weatherstation/__init__.py homeassistant/components/trafikverket_weatherstation/coordinator.py homeassistant/components/trafikverket_weatherstation/sensor.py diff --git a/homeassistant/components/trafikverket_train/config_flow.py b/homeassistant/components/trafikverket_train/config_flow.py index fc23d3b953d..f5000851755 100644 --- a/homeassistant/components/trafikverket_train/config_flow.py +++ b/homeassistant/components/trafikverket_train/config_flow.py @@ -2,18 +2,24 @@ from __future__ import annotations from collections.abc import Mapping +from datetime import datetime +import logging from typing import Any from pytrafikverket import TrafikverketTrain from pytrafikverket.exceptions import ( InvalidAuthentication, + MultipleTrainAnnouncementFound, MultipleTrainStationsFound, + NoTrainAnnouncementFound, NoTrainStationFound, + UnknownError, ) import voluptuous as vol from homeassistant import config_entries from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS +from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv @@ -22,18 +28,21 @@ from homeassistant.helpers.selector import ( SelectSelectorConfig, SelectSelectorMode, TextSelector, + TimeSelector, ) import homeassistant.util.dt as dt_util from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN -from .util import create_unique_id +from .util import create_unique_id, next_departuredate + +_LOGGER = logging.getLogger(__name__) DATA_SCHEMA = vol.Schema( { vol.Required(CONF_API_KEY): TextSelector(), vol.Required(CONF_FROM): TextSelector(), vol.Required(CONF_TO): TextSelector(), - vol.Optional(CONF_TIME): TextSelector(), + vol.Optional(CONF_TIME): TimeSelector(), vol.Required(CONF_WEEKDAY, default=WEEKDAYS): SelectSelector( SelectSelectorConfig( options=WEEKDAYS, @@ -51,6 +60,56 @@ DATA_SCHEMA_REAUTH = vol.Schema( ) +async def validate_input( + hass: HomeAssistant, + api_key: str, + train_from: str, + train_to: str, + train_time: str | None, + weekdays: list[str], +) -> dict[str, str]: + """Validate input from user input.""" + errors: dict[str, str] = {} + + when = dt_util.now() + if train_time: + departure_day = next_departuredate(weekdays) + if _time := dt_util.parse_time(train_time): + when = datetime.combine( + departure_day, + _time, + dt_util.get_time_zone(hass.config.time_zone), + ) + + try: + web_session = async_get_clientsession(hass) + train_api = TrafikverketTrain(web_session, api_key) + from_station = await train_api.async_get_train_station(train_from) + to_station = await train_api.async_get_train_station(train_to) + if train_time: + await train_api.async_get_train_stop(from_station, to_station, when) + else: + await train_api.async_get_next_train_stop(from_station, to_station, when) + except InvalidAuthentication: + errors["base"] = "invalid_auth" + except NoTrainStationFound: + errors["base"] = "invalid_station" + except MultipleTrainStationsFound: + errors["base"] = "more_stations" + except NoTrainAnnouncementFound: + errors["base"] = "no_trains" + except MultipleTrainAnnouncementFound: + errors["base"] = "multiple_trains" + except UnknownError as error: + _LOGGER.error("Unknown error occurred during validation %s", str(error)) + errors["base"] = "cannot_connect" + except Exception as error: # pylint: disable=broad-exception-caught + _LOGGER.error("Unknown exception occurred during validation %s", str(error)) + errors["base"] = "cannot_connect" + + return errors + + class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow for Trafikverket Train integration.""" @@ -58,15 +117,6 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): entry: config_entries.ConfigEntry | None - async def validate_input( - self, api_key: str, train_from: str, train_to: str - ) -> None: - """Validate input from user input.""" - web_session = async_get_clientsession(self.hass) - train_api = TrafikverketTrain(web_session, api_key) - await train_api.async_get_train_station(train_from) - await train_api.async_get_train_station(train_to) - async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: """Handle re-authentication with Trafikverket.""" @@ -83,19 +133,15 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): api_key = user_input[CONF_API_KEY] assert self.entry is not None - try: - await self.validate_input( - api_key, self.entry.data[CONF_FROM], self.entry.data[CONF_TO] - ) - except InvalidAuthentication: - errors["base"] = "invalid_auth" - except NoTrainStationFound: - errors["base"] = "invalid_station" - except MultipleTrainStationsFound: - errors["base"] = "more_stations" - except Exception: # pylint: disable=broad-exception-caught - errors["base"] = "cannot_connect" - else: + errors = await validate_input( + self.hass, + api_key, + self.entry.data[CONF_FROM], + self.entry.data[CONF_TO], + self.entry.data.get(CONF_TIME), + self.entry.data[CONF_WEEKDAY], + ) + if not errors: self.hass.config_entries.async_update_entry( self.entry, data={ @@ -129,40 +175,36 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): if train_time: name = f"{train_from} to {train_to} at {train_time}" - try: - await self.validate_input(api_key, train_from, train_to) - except InvalidAuthentication: - errors["base"] = "invalid_auth" - except NoTrainStationFound: - errors["base"] = "invalid_station" - except MultipleTrainStationsFound: - errors["base"] = "more_stations" - except Exception: # pylint: disable=broad-exception-caught - errors["base"] = "cannot_connect" - else: - if train_time: - if bool(dt_util.parse_time(train_time) is None): - errors["base"] = "invalid_time" - if not errors: - unique_id = create_unique_id( - train_from, train_to, train_time, train_days - ) - await self.async_set_unique_id(unique_id) - self._abort_if_unique_id_configured() - return self.async_create_entry( - title=name, - data={ - CONF_API_KEY: api_key, - CONF_NAME: name, - CONF_FROM: train_from, - CONF_TO: train_to, - CONF_TIME: train_time, - CONF_WEEKDAY: train_days, - }, - ) + errors = await validate_input( + self.hass, + api_key, + train_from, + train_to, + train_time, + train_days, + ) + if not errors: + unique_id = create_unique_id( + train_from, train_to, train_time, train_days + ) + await self.async_set_unique_id(unique_id) + self._abort_if_unique_id_configured() + return self.async_create_entry( + title=name, + data={ + CONF_API_KEY: api_key, + CONF_NAME: name, + CONF_FROM: train_from, + CONF_TO: train_to, + CONF_TIME: train_time, + CONF_WEEKDAY: train_days, + }, + ) return self.async_show_form( step_id="user", - data_schema=DATA_SCHEMA, + data_schema=self.add_suggested_values_to_schema( + DATA_SCHEMA, user_input or {} + ), errors=errors, ) diff --git a/homeassistant/components/trafikverket_train/coordinator.py b/homeassistant/components/trafikverket_train/coordinator.py index 3125fea8e39..fac1c418b09 100644 --- a/homeassistant/components/trafikverket_train/coordinator.py +++ b/homeassistant/components/trafikverket_train/coordinator.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import date, datetime, time, timedelta +from datetime import datetime, time, timedelta import logging from pytrafikverket import TrafikverketTrain @@ -15,7 +15,7 @@ from pytrafikverket.exceptions import ( from pytrafikverket.trafikverket_train import StationInfo, TrainStop from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_API_KEY, CONF_WEEKDAY, WEEKDAYS +from homeassistant.const import CONF_API_KEY, CONF_WEEKDAY from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers.aiohttp_client import async_get_clientsession @@ -23,6 +23,7 @@ from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, Upda from homeassistant.util import dt as dt_util from .const import CONF_TIME, DOMAIN +from .util import next_departuredate @dataclass @@ -44,27 +45,6 @@ _LOGGER = logging.getLogger(__name__) TIME_BETWEEN_UPDATES = timedelta(minutes=5) -def _next_weekday(fromdate: date, weekday: int) -> date: - """Return the date of the next time a specific weekday happen.""" - days_ahead = weekday - fromdate.weekday() - if days_ahead <= 0: - days_ahead += 7 - return fromdate + timedelta(days_ahead) - - -def _next_departuredate(departure: list[str]) -> date: - """Calculate the next departuredate from an array input of short days.""" - today_date = date.today() - today_weekday = date.weekday(today_date) - if WEEKDAYS[today_weekday] in departure: - return today_date - for day in departure: - next_departure = WEEKDAYS.index(day) - if next_departure > today_weekday: - return _next_weekday(today_date, next_departure) - return _next_weekday(today_date, WEEKDAYS.index(departure[0])) - - def _get_as_utc(date_value: datetime | None) -> datetime | None: """Return utc datetime or None.""" if date_value: @@ -110,7 +90,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]): when = dt_util.now() state: TrainStop | None = None if self._time: - departure_day = _next_departuredate(self._weekdays) + departure_day = next_departuredate(self._weekdays) when = datetime.combine( departure_day, self._time, diff --git a/homeassistant/components/trafikverket_train/strings.json b/homeassistant/components/trafikverket_train/strings.json index 59431107ae2..aabab0907ab 100644 --- a/homeassistant/components/trafikverket_train/strings.json +++ b/homeassistant/components/trafikverket_train/strings.json @@ -9,7 +9,8 @@ "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "invalid_station": "Could not find a station with the specified name", "more_stations": "Found multiple stations with the specified name", - "invalid_time": "Invalid time provided", + "no_trains": "No train found", + "multiple_trains": "Multiple trains found", "incorrect_api_key": "Invalid API key for selected account" }, "step": { @@ -20,6 +21,9 @@ "from": "From station", "time": "Time (optional)", "weekday": "Days" + }, + "data_description": { + "time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure" } }, "reauth_confirm": { diff --git a/homeassistant/components/trafikverket_train/util.py b/homeassistant/components/trafikverket_train/util.py index 6ed672c9e7e..c5553c4a4a7 100644 --- a/homeassistant/components/trafikverket_train/util.py +++ b/homeassistant/components/trafikverket_train/util.py @@ -1,7 +1,9 @@ """Utils for trafikverket_train.""" from __future__ import annotations -from datetime import time +from datetime import date, time, timedelta + +from homeassistant.const import WEEKDAYS def create_unique_id( @@ -13,3 +15,24 @@ def create_unique_id( f"{from_station.casefold().replace(' ', '')}-{to_station.casefold().replace(' ', '')}" f"-{timestr.casefold().replace(' ', '')}-{str(weekdays)}" ) + + +def next_weekday(fromdate: date, weekday: int) -> date: + """Return the date of the next time a specific weekday happen.""" + days_ahead = weekday - fromdate.weekday() + if days_ahead <= 0: + days_ahead += 7 + return fromdate + timedelta(days_ahead) + + +def next_departuredate(departure: list[str]) -> date: + """Calculate the next departuredate from an array input of short days.""" + today_date = date.today() + today_weekday = date.weekday(today_date) + if WEEKDAYS[today_weekday] in departure: + return today_date + for day in departure: + next_departure = WEEKDAYS.index(day) + if next_departure > today_weekday: + return next_weekday(today_date, next_departure) + return next_weekday(today_date, WEEKDAYS.index(departure[0])) diff --git a/tests/components/trafikverket_train/test_config_flow.py b/tests/components/trafikverket_train/test_config_flow.py index 424e1d74162..a3b449755c7 100644 --- a/tests/components/trafikverket_train/test_config_flow.py +++ b/tests/components/trafikverket_train/test_config_flow.py @@ -6,8 +6,11 @@ from unittest.mock import patch import pytest from pytrafikverket.exceptions import ( InvalidAuthentication, + MultipleTrainAnnouncementFound, MultipleTrainStationsFound, + NoTrainAnnouncementFound, NoTrainStationFound, + UnknownError, ) from homeassistant import config_entries @@ -35,11 +38,13 @@ async def test_form(hass: HomeAssistant) -> None: with patch( "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", ), patch( "homeassistant.components.trafikverket_train.async_setup_entry", return_value=True, ) as mock_setup_entry: - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( result["flow_id"], { CONF_API_KEY: "1234567890", @@ -51,9 +56,9 @@ async def test_form(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - assert result2["type"] == FlowResultType.CREATE_ENTRY - assert result2["title"] == "Stockholm C to Uppsala C at 10:00" - assert result2["data"] == { + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "Stockholm C to Uppsala C at 10:00" + assert result["data"] == { "api_key": "1234567890", "name": "Stockholm C to Uppsala C at 10:00", "from": "Stockholm C", @@ -62,7 +67,7 @@ async def test_form(hass: HomeAssistant) -> None: "weekday": ["mon", "fri"], } assert len(mock_setup_entry.mock_calls) == 1 - assert result2["result"].unique_id == "{}-{}-{}-{}".format( + assert result["result"].unique_id == "{}-{}-{}-{}".format( "stockholmc", "uppsalac", "10:00", "['mon', 'fri']" ) @@ -92,11 +97,13 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None: with patch( "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", ), patch( "homeassistant.components.trafikverket_train.async_setup_entry", return_value=True, ): - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( result["flow_id"], { CONF_API_KEY: "1234567890", @@ -108,8 +115,8 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None: ) await hass.async_block_till_done() - assert result2["type"] == FlowResultType.ABORT - assert result2["reason"] == "already_configured" + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "already_configured" @pytest.mark.parametrize( @@ -137,19 +144,21 @@ async def test_flow_fails( hass: HomeAssistant, side_effect: Exception, base_error: str ) -> None: """Test config flow errors.""" - result4 = await hass.config_entries.flow.async_init( + result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - assert result4["type"] == FlowResultType.FORM - assert result4["step_id"] == config_entries.SOURCE_USER + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == config_entries.SOURCE_USER with patch( "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", side_effect=side_effect(), + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", ): - result4 = await hass.config_entries.flow.async_configure( - result4["flow_id"], + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={ CONF_API_KEY: "1234567890", CONF_FROM: "Stockholm C", @@ -157,32 +166,55 @@ async def test_flow_fails( }, ) - assert result4["errors"] == {"base": base_error} + assert result["errors"] == {"base": base_error} -async def test_flow_fails_incorrect_time(hass: HomeAssistant) -> None: - """Test config flow errors due to bad time.""" - result5 = await hass.config_entries.flow.async_init( +@pytest.mark.parametrize( + ("side_effect", "base_error"), + [ + ( + NoTrainAnnouncementFound, + "no_trains", + ), + ( + MultipleTrainAnnouncementFound, + "multiple_trains", + ), + ( + UnknownError, + "cannot_connect", + ), + ], +) +async def test_flow_fails_departures( + hass: HomeAssistant, side_effect: Exception, base_error: str +) -> None: + """Test config flow errors.""" + result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) - assert result5["type"] == FlowResultType.FORM - assert result5["step_id"] == config_entries.SOURCE_USER + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == config_entries.SOURCE_USER with patch( "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_next_train_stop", + side_effect=side_effect(), + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", ): - result6 = await hass.config_entries.flow.async_configure( - result5["flow_id"], + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={ CONF_API_KEY: "1234567890", CONF_FROM: "Stockholm C", CONF_TO: "Uppsala C", - CONF_TIME: "25:25", }, ) - assert result6["errors"] == {"base": "invalid_time"} + assert result["errors"] == {"base": base_error} async def test_reauth_flow(hass: HomeAssistant) -> None: @@ -216,18 +248,20 @@ async def test_reauth_flow(hass: HomeAssistant) -> None: with patch( "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", ), patch( "homeassistant.components.trafikverket_train.async_setup_entry", return_value=True, ): - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( result["flow_id"], {CONF_API_KEY: "1234567891"}, ) await hass.async_block_till_done() - assert result2["type"] == FlowResultType.ABORT - assert result2["reason"] == "reauth_successful" + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reauth_successful" assert entry.data == { "api_key": "1234567891", "name": "Stockholm C to Uppsala C at 10:00", @@ -290,31 +324,122 @@ async def test_reauth_flow_error( with patch( "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", side_effect=side_effect(), + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", ): - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( result["flow_id"], {CONF_API_KEY: "1234567890"}, ) await hass.async_block_till_done() - assert result2["step_id"] == "reauth_confirm" - assert result2["type"] == FlowResultType.FORM - assert result2["errors"] == {"base": p_error} + assert result["step_id"] == "reauth_confirm" + assert result["type"] == FlowResultType.FORM + assert result["errors"] == {"base": p_error} with patch( "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", ), patch( "homeassistant.components.trafikverket_train.async_setup_entry", return_value=True, ): - result2 = await hass.config_entries.flow.async_configure( + result = await hass.config_entries.flow.async_configure( result["flow_id"], {CONF_API_KEY: "1234567891"}, ) await hass.async_block_till_done() - assert result2["type"] == FlowResultType.ABORT - assert result2["reason"] == "reauth_successful" + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert entry.data == { + "api_key": "1234567891", + "name": "Stockholm C to Uppsala C at 10:00", + "from": "Stockholm C", + "to": "Uppsala C", + "time": "10:00", + "weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"], + } + + +@pytest.mark.parametrize( + ("side_effect", "p_error"), + [ + ( + NoTrainAnnouncementFound, + "no_trains", + ), + ( + MultipleTrainAnnouncementFound, + "multiple_trains", + ), + ( + UnknownError, + "cannot_connect", + ), + ], +) +async def test_reauth_flow_error_departures( + hass: HomeAssistant, side_effect: Exception, p_error: str +) -> None: + """Test a reauthentication flow with error.""" + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_API_KEY: "1234567890", + CONF_NAME: "Stockholm C to Uppsala C at 10:00", + CONF_FROM: "Stockholm C", + CONF_TO: "Uppsala C", + CONF_TIME: "10:00", + CONF_WEEKDAY: WEEKDAYS, + }, + unique_id=f"stockholmc-uppsalac-10:00-{WEEKDAYS}", + ) + entry.add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={ + "source": config_entries.SOURCE_REAUTH, + "unique_id": entry.unique_id, + "entry_id": entry.entry_id, + }, + data=entry.data, + ) + + with patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", + side_effect=side_effect(), + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_API_KEY: "1234567890"}, + ) + await hass.async_block_till_done() + + assert result["step_id"] == "reauth_confirm" + assert result["type"] == FlowResultType.FORM + assert result["errors"] == {"base": p_error} + + with patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", + ), patch( + "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", + ), patch( + "homeassistant.components.trafikverket_train.async_setup_entry", + return_value=True, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_API_KEY: "1234567891"}, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reauth_successful" assert entry.data == { "api_key": "1234567891", "name": "Stockholm C to Uppsala C at 10:00",