diff --git a/homeassistant/components/swiss_public_transport/config_flow.py b/homeassistant/components/swiss_public_transport/config_flow.py index 5133e9374b16ef..10d8d6ae1e3ee1 100644 --- a/homeassistant/components/swiss_public_transport/config_flow.py +++ b/homeassistant/components/swiss_public_transport/config_flow.py @@ -15,6 +15,9 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.selector import ( DurationSelector, + SelectSelector, + SelectSelectorConfig, + SelectSelectorMode, TextSelector, TextSelectorConfig, TextSelectorType, @@ -26,11 +29,14 @@ CONF_IS_ARRIVAL, CONF_START, CONF_TIME, + CONF_TIME_MODE, CONF_TIME_OFFSET, CONF_VIA, + DEFAULT_TIME_MODE, DOMAIN, MAX_VIA, PLACEHOLDERS, + TIME_MODE_OPTIONS, ) from .helper import ( dict_duration_to_str_duration, @@ -38,7 +44,7 @@ unique_id_from_config, ) -DATA_SCHEMA = vol.Schema( +USER_DATA_SCHEMA = vol.Schema( { vol.Required(CONF_START): cv.string, vol.Optional(CONF_VIA): TextSelector( @@ -48,11 +54,19 @@ ), ), vol.Required(CONF_DESTINATION): cv.string, - vol.Optional(CONF_TIME): TimeSelector(), - vol.Optional(CONF_TIME_OFFSET): DurationSelector(), vol.Optional(CONF_IS_ARRIVAL): bool, + vol.Optional(CONF_TIME_MODE): SelectSelector( + SelectSelectorConfig( + options=TIME_MODE_OPTIONS, + mode=SelectSelectorMode.DROPDOWN, + translation_key="time_mode", + ), + ), } ) +ADVANCED_TIME_DATA_SCHEMA = {vol.Optional(CONF_TIME): TimeSelector()} +ADVANCED_TIME_OFFSET_DATA_SCHEMA = {vol.Optional(CONF_TIME_OFFSET): DurationSelector()} + _LOGGER = logging.getLogger(__name__) @@ -63,30 +77,18 @@ class SwissPublicTransportConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 3 MINOR_VERSION = 1 + user_input: dict[str, Any] + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Async user step to set up the connection.""" errors: dict[str, str] = {} if user_input is not None: - unique_id = unique_id_from_config(user_input) - await self.async_set_unique_id(unique_id) - self._abort_if_unique_id_configured() - if CONF_VIA in user_input and len(user_input[CONF_VIA]) > MAX_VIA: errors["base"] = "too_many_via_stations" - elif CONF_TIME in user_input and CONF_TIME_OFFSET in user_input: - errors["base"] = "mutex_time_offset" else: session = async_get_clientsession(self.hass) - time_offset_dict: dict[str, int] | None = user_input.get( - CONF_TIME_OFFSET - ) - time_offset = ( - dict_duration_to_str_duration(time_offset_dict) - if CONF_TIME_OFFSET in user_input and time_offset_dict is not None - else None - ) opendata = OpendataTransport( user_input[CONF_START], user_input[CONF_DESTINATION], @@ -94,26 +96,95 @@ async def async_step_user( via=user_input.get(CONF_VIA), time=user_input.get(CONF_TIME), ) - if time_offset: - offset_opendata(opendata, time_offset) - try: - await opendata.async_get_data() - except OpendataTransportConnectionError: - errors["base"] = "cannot_connect" - except OpendataTransportError: - errors["base"] = "bad_config" - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Unknown error") - errors["base"] = "unknown" + err = await self.fetch_connections(opendata) + if err: + errors["base"] = err else: - return self.async_create_entry( - title=unique_id, - data=user_input, + if user_input[CONF_TIME_MODE] == "now": + unique_id = unique_id_from_config(user_input) + await self.async_set_unique_id(unique_id) + self._abort_if_unique_id_configured() + return self.async_create_entry( + title=unique_id, + data=user_input, + ) + self.user_input = user_input + return self.async_show_form( + step_id="advanced", + data_schema=self.build_advanced_schema(user_input), + errors=errors, + description_placeholders=PLACEHOLDERS, ) return self.async_show_form( step_id="user", - data_schema=DATA_SCHEMA, + data_schema=self.add_suggested_values_to_schema( + data_schema=USER_DATA_SCHEMA, + suggested_values=user_input or {CONF_TIME_MODE: DEFAULT_TIME_MODE}, + ), errors=errors, description_placeholders=PLACEHOLDERS, ) + + async def async_step_advanced( + self, advanced_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Async advanced step to set up the connection.""" + errors: dict[str, str] = {} + if advanced_input is not None: + unique_id = unique_id_from_config({**self.user_input, **advanced_input}) + await self.async_set_unique_id(unique_id) + self._abort_if_unique_id_configured() + + session = async_get_clientsession(self.hass) + time_offset_dict: dict[str, int] | None = advanced_input.get( + CONF_TIME_OFFSET + ) + time_offset = ( + dict_duration_to_str_duration(time_offset_dict) + if CONF_TIME_OFFSET in advanced_input and time_offset_dict is not None + else None + ) + opendata = OpendataTransport( + self.user_input[CONF_START], + self.user_input[CONF_DESTINATION], + session, + via=self.user_input.get(CONF_VIA), + time=advanced_input.get(CONF_TIME), + ) + if time_offset: + offset_opendata(opendata, time_offset) + err = await self.fetch_connections(opendata) + if err: + errors["base"] = err + else: + return self.async_create_entry( + title=unique_id, + data={**self.user_input, **advanced_input}, + ) + + return self.async_show_form( + step_id="advanced", + data_schema=self.build_advanced_schema(self.user_input), + errors=errors, + description_placeholders=PLACEHOLDERS, + ) + + async def fetch_connections(self, opendata: OpendataTransport) -> str | None: + """Fetch the connections and advancedly return an error.""" + try: + await opendata.async_get_data() + except OpendataTransportConnectionError: + return "cannot_connect" + except OpendataTransportError: + return "bad_config" + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Unknown error") + return "unknown" + return None + + def build_advanced_schema(self, user_input: dict[str, Any]) -> vol.Schema: + """Build the advanced schema.""" + if user_input[CONF_TIME_MODE] == "fixed": + return vol.Schema(ADVANCED_TIME_DATA_SCHEMA) + return vol.Schema(ADVANCED_TIME_OFFSET_DATA_SCHEMA) diff --git a/homeassistant/components/swiss_public_transport/const.py b/homeassistant/components/swiss_public_transport/const.py index be18ac1b1ea584..027b3cd7430a74 100644 --- a/homeassistant/components/swiss_public_transport/const.py +++ b/homeassistant/components/swiss_public_transport/const.py @@ -7,17 +7,20 @@ CONF_DESTINATION: Final = "to" CONF_START: Final = "from" CONF_VIA: Final = "via" +CONF_IS_ARRIVAL: Final = "is_arrival" +CONF_TIME_MODE: Final = "time_mode" CONF_TIME: Final = "time" CONF_TIME_OFFSET: Final = "time_offset" -CONF_IS_ARRIVAL: Final = "is_arrival" DEFAULT_NAME = "Next Destination" DEFAULT_UPDATE_TIME = 90 DEFAULT_IS_ARRIVAL = False +DEFAULT_TIME_MODE = "now" MAX_VIA = 5 CONNECTIONS_COUNT = 3 CONNECTIONS_MAX = 15 +TIME_MODE_OPTIONS = ["now", "fixed", "offset"] PLACEHOLDERS = { diff --git a/homeassistant/components/swiss_public_transport/strings.json b/homeassistant/components/swiss_public_transport/strings.json index aecc3e6ee4deee..436a77c8227b60 100644 --- a/homeassistant/components/swiss_public_transport/strings.json +++ b/homeassistant/components/swiss_public_transport/strings.json @@ -19,11 +19,17 @@ "from": "Start station", "to": "End station", "via": "List of up to 5 via stations", - "time": "Select a fixed time of day", - "time_offset": "Select a moving time offset", - "is_arrival": "Use arrival instead of departure for time and offset configuration" + "is_arrival": "Use arrival instead of departure", + "time_mode": "Select a time mode" + }, + "description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\nOptionally, you can also configure connections at a specific time or moving offset.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.", + "title": "Swiss Public Transport" + }, + "optional": { + "data": { + "time": "Select the time of day", + "time_offset": "Select an offset duration" }, - "description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\nOptionally, you can also configure connections at a specific time or moving offset for departure or arrival.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.", "title": "Swiss Public Transport" } } @@ -88,5 +94,14 @@ "config_entry_not_found": { "message": "Swiss public transport integration instance \"{target}\" not found." } + }, + "selector": { + "time_mode": { + "options": { + "now": "Now", + "fixed": "At a fixed time of day", + "offset": "At an offset from now" + } + } } } diff --git a/tests/components/swiss_public_transport/test_config_flow.py b/tests/components/swiss_public_transport/test_config_flow.py index 155f36fade19b7..03c996334ec806 100644 --- a/tests/components/swiss_public_transport/test_config_flow.py +++ b/tests/components/swiss_public_transport/test_config_flow.py @@ -14,6 +14,7 @@ CONF_IS_ARRIVAL, CONF_START, CONF_TIME, + CONF_TIME_MODE, CONF_TIME_OFFSET, CONF_VIA, MAX_VIA, @@ -26,62 +27,85 @@ pytestmark = pytest.mark.usefixtures("mock_setup_entry") -MOCK_DATA_STEP = { +MOCK_USER_DATA_STEP = { CONF_START: "test_start", CONF_DESTINATION: "test_destination", + CONF_TIME_MODE: "now", } -MOCK_DATA_STEP_ONE_VIA = { - **MOCK_DATA_STEP, +MOCK_USER_DATA_STEP_ONE_VIA = { + **MOCK_USER_DATA_STEP, CONF_VIA: ["via_station"], } -MOCK_DATA_STEP_MANY_VIA = { - **MOCK_DATA_STEP, +MOCK_USER_DATA_STEP_MANY_VIA = { + **MOCK_USER_DATA_STEP, CONF_VIA: ["via_station_1", "via_station_2", "via_station_3"], } -MOCK_DATA_STEP_TOO_MANY_STATIONS = { - **MOCK_DATA_STEP, - CONF_VIA: MOCK_DATA_STEP_ONE_VIA[CONF_VIA] * (MAX_VIA + 1), +MOCK_USER_DATA_STEP_TOO_MANY_STATIONS = { + **MOCK_USER_DATA_STEP, + CONF_VIA: MOCK_USER_DATA_STEP_ONE_VIA[CONF_VIA] * (MAX_VIA + 1), } -MOCK_DATA_STEP_TIME = { - **MOCK_DATA_STEP, - CONF_TIME: "18:03:00", +MOCK_USER_DATA_STEP_ARRIVAL = { + **MOCK_USER_DATA_STEP, + CONF_IS_ARRIVAL: True, } -MOCK_DATA_STEP_TIME_OFFSET_ARRIVAL = { - **MOCK_DATA_STEP, - CONF_TIME_OFFSET: {"hours": 0, "minutes": 10, "seconds": 0}, - CONF_IS_ARRIVAL: True, +MOCK_USER_DATA_STEP_TIME = { + **MOCK_USER_DATA_STEP, + CONF_TIME_MODE: "fixed", +} + +MOCK_USER_DATA_STEP_TIME_OFFSET = { + **MOCK_USER_DATA_STEP, + CONF_TIME_MODE: "offset", +} + +MOCK_USER_DATA_STEP_BAD = { + **MOCK_USER_DATA_STEP, + CONF_TIME_MODE: "bad", } -MOCK_DATA_STEP_TIME_OFFSET_MUTEX = { - **MOCK_DATA_STEP, +MOCK_ADVANCED_DATA_STEP_TIME = { CONF_TIME: "18:03:00", +} + +MOCK_ADVANCED_DATA_STEP_TIME_OFFSET = { CONF_TIME_OFFSET: {"hours": 0, "minutes": 10, "seconds": 0}, } @pytest.mark.parametrize( - ("user_input", "config_title"), + ("user_input", "advanced_input", "config_title"), [ - (MOCK_DATA_STEP, "test_start test_destination"), - (MOCK_DATA_STEP_ONE_VIA, "test_start test_destination via via_station"), + (MOCK_USER_DATA_STEP, None, "test_start test_destination"), ( - MOCK_DATA_STEP_MANY_VIA, + MOCK_USER_DATA_STEP_ONE_VIA, + None, + "test_start test_destination via via_station", + ), + ( + MOCK_USER_DATA_STEP_MANY_VIA, + None, "test_start test_destination via via_station_1, via_station_2, via_station_3", ), - (MOCK_DATA_STEP_TIME, "test_start test_destination at 18:03:00"), + (MOCK_USER_DATA_STEP_ARRIVAL, None, "test_start test_destination arrival"), ( - MOCK_DATA_STEP_TIME_OFFSET_ARRIVAL, - "test_start test_destination arrival in 00:10:00", + MOCK_USER_DATA_STEP_TIME, + MOCK_ADVANCED_DATA_STEP_TIME, + "test_start test_destination at 18:03:00", + ), + ( + MOCK_USER_DATA_STEP_TIME_OFFSET, + MOCK_ADVANCED_DATA_STEP_TIME_OFFSET, + "test_start test_destination in 00:10:00", ), ], ) async def test_flow_user_init_data_success( - hass: HomeAssistant, user_input, config_title + hass: HomeAssistant, user_input, advanced_input, config_title ) -> None: """Test success response.""" result = await hass.config_entries.flow.async_init( @@ -91,49 +115,53 @@ async def test_flow_user_init_data_success( assert result["type"] is FlowResultType.FORM assert result["step_id"] == "user" assert result["handler"] == "swiss_public_transport" - assert result["data_schema"] == config_flow.DATA_SCHEMA + assert result["data_schema"] == config_flow.USER_DATA_SCHEMA with patch( "homeassistant.components.swiss_public_transport.config_flow.OpendataTransport.async_get_data", autospec=True, return_value=True, ): - result = await hass.config_entries.flow.async_init( - config_flow.DOMAIN, context={"source": "user"} - ) result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input=user_input, ) + if advanced_input: + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "advanced" + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=advanced_input, + ) + assert result["type"] == FlowResultType.CREATE_ENTRY assert result["result"].title == config_title - assert result["data"] == user_input + assert result["data"] == {**user_input, **(advanced_input or {})} @pytest.mark.parametrize( ("raise_error", "text_error", "user_input_error"), [ - (OpendataTransportConnectionError(), "cannot_connect", MOCK_DATA_STEP), - (OpendataTransportError(), "bad_config", MOCK_DATA_STEP), - (None, "too_many_via_stations", MOCK_DATA_STEP_TOO_MANY_STATIONS), - (None, "mutex_time_offset", MOCK_DATA_STEP_TIME_OFFSET_MUTEX), - (IndexError(), "unknown", MOCK_DATA_STEP), + (OpendataTransportConnectionError(), "cannot_connect", MOCK_USER_DATA_STEP), + (OpendataTransportError(), "bad_config", MOCK_USER_DATA_STEP), + (None, "too_many_via_stations", MOCK_USER_DATA_STEP_TOO_MANY_STATIONS), + (IndexError(), "unknown", MOCK_USER_DATA_STEP), ], ) -async def test_flow_user_init_data_error_and_recover( +async def test_flow_user_init_data_error_and_recover_on_step_1( hass: HomeAssistant, raise_error, text_error, user_input_error ) -> None: """Test unknown errors.""" + result = await hass.config_entries.flow.async_init( + config_flow.DOMAIN, context={"source": "user"} + ) with patch( "homeassistant.components.swiss_public_transport.config_flow.OpendataTransport.async_get_data", autospec=True, side_effect=raise_error, ) as mock_OpendataTransport: - result = await hass.config_entries.flow.async_init( - config_flow.DOMAIN, context={"source": "user"} - ) result = await hass.config_entries.flow.async_configure( result["flow_id"], user_input=user_input_error, @@ -147,13 +175,75 @@ async def test_flow_user_init_data_error_and_recover( mock_OpendataTransport.return_value = True result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=MOCK_DATA_STEP, + user_input=MOCK_USER_DATA_STEP, ) assert result["type"] == FlowResultType.CREATE_ENTRY assert result["result"].title == "test_start test_destination" - assert result["data"] == MOCK_DATA_STEP + assert result["data"] == MOCK_USER_DATA_STEP + + +@pytest.mark.parametrize( + ("raise_error", "text_error", "user_input_error"), + [ + ( + OpendataTransportConnectionError(), + "cannot_connect", + MOCK_ADVANCED_DATA_STEP_TIME, + ), + (OpendataTransportError(), "bad_config", MOCK_ADVANCED_DATA_STEP_TIME), + (IndexError(), "unknown", MOCK_ADVANCED_DATA_STEP_TIME), + ], +) +async def test_flow_user_init_data_error_and_recover_on_step_2( + hass: HomeAssistant, raise_error, text_error, user_input_error +) -> None: + """Test unknown errors.""" + result = await hass.config_entries.flow.async_init( + config_flow.DOMAIN, context={"source": "user"} + ) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "user" + assert result["handler"] == "swiss_public_transport" + assert result["data_schema"] == config_flow.USER_DATA_SCHEMA + + with patch( + "homeassistant.components.swiss_public_transport.config_flow.OpendataTransport.async_get_data", + autospec=True, + return_value=True, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=MOCK_USER_DATA_STEP_TIME, + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "advanced" + + with patch( + "homeassistant.components.swiss_public_transport.config_flow.OpendataTransport.async_get_data", + autospec=True, + side_effect=raise_error, + ) as mock_OpendataTransport: + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input_error, + ) + + assert result["type"] is FlowResultType.FORM + assert result["errors"]["base"] == text_error + + # Recover + mock_OpendataTransport.side_effect = None + mock_OpendataTransport.return_value = True + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input_error, + ) + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["result"].title == "test_start test_destination at 18:03:00" async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> None: @@ -161,8 +251,8 @@ async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> No entry = MockConfigEntry( domain=config_flow.DOMAIN, - data=MOCK_DATA_STEP, - unique_id=unique_id_from_config(MOCK_DATA_STEP), + data=MOCK_USER_DATA_STEP, + unique_id=unique_id_from_config(MOCK_USER_DATA_STEP), ) entry.add_to_hass(hass) @@ -177,7 +267,7 @@ async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> No result = await hass.config_entries.flow.async_configure( result["flow_id"], - user_input=MOCK_DATA_STEP, + user_input=MOCK_USER_DATA_STEP, ) assert result["type"] is FlowResultType.ABORT