diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py index 97ebfac4be2..cfa47dfe6cc 100644 --- a/homeassistant/components/sql/config_flow.py +++ b/homeassistant/components/sql/config_flow.py @@ -160,12 +160,12 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN): if user_input is not None: db_url = user_input.get(CONF_DB_URL) - query = user_input[CONF_QUERY] + user_query = user_input[CONF_QUERY] column = user_input[CONF_COLUMN_NAME] db_url_for_validation = None try: - query = validate_sql_select(query) + query = validate_sql_select(user_query) db_url_for_validation = resolve_db_url(self.hass, db_url) await self.hass.async_add_executor_job( validate_query, db_url_for_validation, query, column @@ -184,7 +184,7 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN): errors["query"] = "query_invalid" options = { - CONF_QUERY: query, + CONF_QUERY: user_query, CONF_COLUMN_NAME: column, CONF_NAME: user_input[CONF_NAME], } @@ -226,12 +226,12 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload): if user_input is not None: db_url = user_input.get(CONF_DB_URL) - query = user_input[CONF_QUERY] + user_query = user_input[CONF_QUERY] column = user_input[CONF_COLUMN_NAME] name = self.config_entry.options.get(CONF_NAME, self.config_entry.title) try: - query = validate_sql_select(query) + query = validate_sql_select(user_query) db_url_for_validation = resolve_db_url(self.hass, db_url) await self.hass.async_add_executor_job( validate_query, db_url_for_validation, query, column @@ -258,7 +258,7 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload): ) options = { - CONF_QUERY: query, + CONF_QUERY: user_query, CONF_COLUMN_NAME: column, CONF_NAME: name, } diff --git a/homeassistant/components/sql/util.py b/homeassistant/components/sql/util.py index 39be6ce2fe3..c8b1fd046b0 100644 --- a/homeassistant/components/sql/util.py +++ b/homeassistant/components/sql/util.py @@ -54,4 +54,5 @@ def check_and_render_sql_query(hass: HomeAssistant, query: Template | str) -> st if query_type != "SELECT": _LOGGER.debug("The SQL query %s is of type %s", rendered_query, query_type) raise ValueError("SQL query must be of type SELECT") + return str(rendered_queries[0]) diff --git a/tests/components/sql/__init__.py b/tests/components/sql/__init__.py index 5f91cba1d94..2f1ade2d00a 100644 --- a/tests/components/sql/__init__.py +++ b/tests/components/sql/__init__.py @@ -46,6 +46,29 @@ ENTRY_CONFIG_WITH_VALUE_TEMPLATE = { CONF_VALUE_TEMPLATE: "{{ value }}", } +ENTRY_CONFIG_WITH_QUERY_TEMPLATE = { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_VALUE_TEMPLATE: "{{ value }}", +} + +ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE = { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT {{ 5 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_VALUE_TEMPLATE: "{{ value }}", +} + +ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT = { + CONF_QUERY: "SELECT {{ 5 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_VALUE_TEMPLATE: "{{ value }}", +} + ENTRY_CONFIG_INVALID_QUERY = { CONF_NAME: "Get Value", CONF_QUERY: "SELECT 5 FROM as value", diff --git a/tests/components/sql/test_config_flow.py b/tests/components/sql/test_config_flow.py index 3f2400c0a32..9d4e957db0b 100644 --- a/tests/components/sql/test_config_flow.py +++ b/tests/components/sql/test_config_flow.py @@ -31,6 +31,9 @@ from . import ( ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE, ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT, ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT, + ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE, + ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT, + ENTRY_CONFIG_WITH_QUERY_TEMPLATE, ENTRY_CONFIG_WITH_VALUE_TEMPLATE, ) @@ -69,6 +72,79 @@ async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 +async def test_form_with_query_template( + recorder_mock: Recorder, hass: HomeAssistant +) -> None: + """Test for with query template.""" + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {} + + with patch( + "homeassistant.components.sql.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + ENTRY_CONFIG_WITH_QUERY_TEMPLATE, + ) + await hass.async_block_till_done() + + assert result2["type"] is FlowResultType.CREATE_ENTRY + assert result2["title"] == "Get Value" + assert result2["options"] == { + "name": "Get Value", + "query": "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value", + "column": "value", + "unit_of_measurement": "MiB", + "value_template": "{{ value }}", + } + assert len(mock_setup_entry.mock_calls) == 1 + + +async def test_form_with_broken_query_template( + recorder_mock: Recorder, hass: HomeAssistant +) -> None: + """Test form with broken query template.""" + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {} + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE, + ) + + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {"query": "query_invalid"} + + with patch( + "homeassistant.components.sql.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + ENTRY_CONFIG_WITH_QUERY_TEMPLATE, + ) + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["title"] == "Get Value" + assert result["options"] == { + "name": "Get Value", + "query": "SELECT {% if states('sensor.input1')=='on' %} 5 {% else %} 6 {% endif %} as value", + "column": "value", + "unit_of_measurement": "MiB", + "value_template": "{{ value }}", + } + assert len(mock_setup_entry.mock_calls) == 1 + + async def test_form_with_value_template( recorder_mock: Recorder, hass: HomeAssistant ) -> None: @@ -432,67 +508,77 @@ async def test_options_flow_fails_invalid_query( result = await hass.config_entries.options.async_init(entry.entry_id) - result2 = await hass.config_entries.options.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input=ENTRY_CONFIG_INVALID_QUERY_OPT, ) - assert result2["type"] is FlowResultType.FORM - assert result2["errors"] == { + assert result["type"] is FlowResultType.FORM + assert result["errors"] == { "query": "query_invalid", } - result3 = await hass.config_entries.options.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input=ENTRY_CONFIG_INVALID_QUERY_2_OPT, ) - assert result3["type"] is FlowResultType.FORM - assert result3["errors"] == { + assert result["type"] is FlowResultType.FORM + assert result["errors"] == { "query": "query_invalid", } - result3 = await hass.config_entries.options.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input=ENTRY_CONFIG_INVALID_QUERY_3_OPT, ) - assert result3["type"] is FlowResultType.FORM - assert result3["errors"] == { + assert result["type"] is FlowResultType.FORM + assert result["errors"] == { "query": "query_invalid", } - result2 = await hass.config_entries.options.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT, ) - assert result2["type"] is FlowResultType.FORM - assert result2["errors"] == { + assert result["type"] is FlowResultType.FORM + assert result["errors"] == { "query": "query_no_read_only", } - result3 = await hass.config_entries.options.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input=ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT, ) - assert result3["type"] is FlowResultType.FORM - assert result3["errors"] == { + assert result["type"] is FlowResultType.FORM + assert result["errors"] == { "query": "query_no_read_only", } - result3 = await hass.config_entries.options.async_configure( + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input=ENTRY_CONFIG_MULTIPLE_QUERIES_OPT, ) - assert result3["type"] is FlowResultType.FORM - assert result3["errors"] == { + assert result["type"] is FlowResultType.FORM + assert result["errors"] == { "query": "multiple_queries", } - result4 = await hass.config_entries.options.async_configure( + result = await hass.config_entries.options.async_configure( + result["flow_id"], + user_input=ENTRY_CONFIG_WITH_BROKEN_QUERY_TEMPLATE_OPT, + ) + + assert result["type"] is FlowResultType.FORM + assert result["errors"] == { + "query": "query_invalid", + } + + result = await hass.config_entries.options.async_configure( result["flow_id"], user_input={ "db_url": "sqlite://", @@ -502,8 +588,8 @@ async def test_options_flow_fails_invalid_query( }, ) - assert result4["type"] is FlowResultType.CREATE_ENTRY - assert result4["data"] == { + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["data"] == { "name": "Get Value", "query": "SELECT 5 as size", "column": "size",