From fcb1e3d501bffb9c7f48ce13fd7b2cd20f00942a Mon Sep 17 00:00:00 2001 From: G Johansson Date: Fri, 8 Aug 2025 15:56:46 +0000 Subject: [PATCH] Allow template in query in sql --- homeassistant/components/sql/__init__.py | 24 ++++----- homeassistant/components/sql/config_flow.py | 27 +++++----- homeassistant/components/sql/sensor.py | 55 +++++++++++++++------ homeassistant/components/sql/util.py | 30 +++++++++++ tests/components/sql/test_init.py | 32 +++++++----- tests/components/sql/test_sensor.py | 10 ++-- 6 files changed, 120 insertions(+), 58 deletions(-) diff --git a/homeassistant/components/sql/__init__.py b/homeassistant/components/sql/__init__.py index 33ed64be2bf..89921daba1b 100644 --- a/homeassistant/components/sql/__init__.py +++ b/homeassistant/components/sql/__init__.py @@ -4,7 +4,6 @@ from __future__ import annotations import logging -import sqlparse import voluptuous as vol from homeassistant.components.recorder import CONF_DB_URL, get_instance @@ -25,6 +24,7 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv, discovery +from homeassistant.helpers.template import Template from homeassistant.helpers.trigger_template_entity import ( CONF_AVAILABILITY, CONF_PICTURE, @@ -33,28 +33,28 @@ from homeassistant.helpers.trigger_template_entity import ( from homeassistant.helpers.typing import ConfigType from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN, PLATFORMS -from .util import redact_credentials +from .util import check_and_render_sql_query, redact_credentials _LOGGER = logging.getLogger(__name__) -def validate_sql_select(value: str) -> str: +def validate_sql_select(value: Template) -> Template: """Validate that value is a SQL SELECT query.""" - if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1: - raise vol.Invalid("Multiple SQL queries are not supported") - if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN": - raise vol.Invalid("Invalid SQL query") - if query_type != "SELECT": - _LOGGER.debug("The SQL query %s is of type %s", query, query_type) - raise vol.Invalid("Only SELECT queries allowed") - return str(query[0]) + try: + assert value.hass + check_and_render_sql_query(value.hass, value) + except ValueError as err: + raise vol.Invalid(str(err)) from err + return value QUERY_SCHEMA = vol.Schema( { vol.Required(CONF_COLUMN_NAME): cv.string, vol.Required(CONF_NAME): cv.template, - vol.Required(CONF_QUERY): vol.All(cv.string, validate_sql_select), + vol.Required(CONF_QUERY): vol.All( + cv.template, ValueTemplate.from_template, validate_sql_select + ), vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, vol.Optional(CONF_VALUE_TEMPLATE): vol.All( cv.template, ValueTemplate.from_template diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py index 37a6f9ef104..97ebfac4be2 100644 --- a/homeassistant/components/sql/config_flow.py +++ b/homeassistant/components/sql/config_flow.py @@ -9,7 +9,6 @@ import sqlalchemy from sqlalchemy.engine import Result from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError from sqlalchemy.orm import Session, scoped_session, sessionmaker -import sqlparse from sqlparse.exceptions import SQLParseError import voluptuous as vol @@ -31,11 +30,11 @@ from homeassistant.const import ( CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE, ) -from homeassistant.core import callback +from homeassistant.core import async_get_hass, callback from homeassistant.helpers import selector from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN -from .util import resolve_db_url +from .util import check_and_render_sql_query, resolve_db_url _LOGGER = logging.getLogger(__name__) @@ -89,14 +88,20 @@ CONFIG_SCHEMA: vol.Schema = vol.Schema( def validate_sql_select(value: str) -> str: """Validate that value is a SQL SELECT query.""" - if len(query := sqlparse.parse(value.lstrip().lstrip(";"))) > 1: - raise MultipleResultsFound - if len(query) == 0 or (query_type := query[0].get_type()) == "UNKNOWN": - raise ValueError - if query_type != "SELECT": - _LOGGER.debug("The SQL query %s is of type %s", query, query_type) - raise SQLParseError - return str(query[0]) + hass = async_get_hass() + try: + return check_and_render_sql_query(hass, value) + except ValueError as err: + err_text = err.args[0] + _LOGGER.debug("Invalid query '%s' results in '%s'", value, err_text) + if err_text == "Multiple SQL statements are not allowed": + raise MultipleResultsFound from err + if err_text in ( + "SQL query must be of type SELECT", + "SQL query must start with SELECT", + ): + raise SQLParseError from err + raise def validate_query(db_url: str, query: str, column: str) -> bool: diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index 8c0ba81d6d2..7f76c4f29ad 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -33,7 +33,7 @@ from homeassistant.const import ( MATCH_ALL, ) from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.exceptions import TemplateError +from homeassistant.exceptions import PlatformNotReady, TemplateError from homeassistant.helpers import issue_registry as ir from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity_platform import ( @@ -51,7 +51,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN from .models import SQLData -from .util import redact_credentials, resolve_db_url +from .util import check_and_render_sql_query, redact_credentials, resolve_db_url _LOGGER = logging.getLogger(__name__) @@ -79,7 +79,7 @@ async def async_setup_platform( return name: Template = conf[CONF_NAME] - query_str: str = conf[CONF_QUERY] + query_template: ValueTemplate = conf[CONF_QUERY] value_template: ValueTemplate | None = conf.get(CONF_VALUE_TEMPLATE) column_name: str = conf[CONF_COLUMN_NAME] unique_id: str | None = conf.get(CONF_UNIQUE_ID) @@ -94,7 +94,7 @@ async def async_setup_platform( await async_setup_sensor( hass, trigger_entity_config, - query_str, + query_template, column_name, value_template, unique_id, @@ -117,6 +117,13 @@ async def async_setup_entry( template: str | None = entry.options.get(CONF_VALUE_TEMPLATE) column_name: str = entry.options[CONF_COLUMN_NAME] + query_template: ValueTemplate | None = None + try: + query_template = ValueTemplate(query_str, hass) + query_template.ensure_valid() + except TemplateError as err: + raise PlatformNotReady("Invalid SQL query template") from err + value_template: ValueTemplate | None = None if template is not None: try: @@ -135,7 +142,7 @@ async def async_setup_entry( await async_setup_sensor( hass, trigger_entity_config, - query_str, + query_template, column_name, value_template, entry.entry_id, @@ -178,7 +185,7 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData: async def async_setup_sensor( hass: HomeAssistant, trigger_entity_config: ConfigType, - query_str: str, + query_template: ValueTemplate, column_name: str, value_template: ValueTemplate | None, unique_id: str | None, @@ -215,6 +222,7 @@ async def async_setup_sensor( else: return + query_str = check_and_render_sql_query(hass, query_template) upper_query = query_str.upper() if uses_recorder_db: redacted_query = redact_credentials(query_str) @@ -252,18 +260,23 @@ async def async_setup_sensor( ) # MSSQL uses TOP and not LIMIT - if not ("LIMIT" in upper_query or "SELECT TOP" in upper_query): + mod_query_template = query_template + if not ("LIMIT" in upper_query or upper_query.startswith("SELECT TOP")): if "mssql" in db_url: - query_str = upper_query.replace("SELECT", "SELECT TOP 1") + mod_query_template = ValueTemplate( + f"SELECT TOP 1{query_template.template[6:]}", hass + ) else: - query_str = query_str.replace(";", "") + " LIMIT 1;" + mod_query_template = ValueTemplate( + f"{query_template.template.replace(';', '')} LIMIT 1;", hass + ) async_add_entities( [ SQLSensor( trigger_entity_config, sessmaker, - query_str, + mod_query_template, column_name, value_template, yaml, @@ -315,7 +328,7 @@ class SQLSensor(ManualTriggerSensorEntity): self, trigger_entity_config: ConfigType, sessmaker: scoped_session, - query: str, + query: ValueTemplate, column: str, value_template: ValueTemplate | None, yaml: bool, @@ -329,7 +342,6 @@ class SQLSensor(ManualTriggerSensorEntity): self.sessionmaker = sessmaker self._attr_extra_state_attributes = {} self._use_database_executor = use_database_executor - self._lambda_stmt = _generate_lambda_stmt(query) if not yaml and (unique_id := trigger_entity_config.get(CONF_UNIQUE_ID)): self._attr_name = None self._attr_has_entity_name = True @@ -371,11 +383,22 @@ class SQLSensor(ManualTriggerSensorEntity): self._attr_extra_state_attributes = {} sess: scoped_session = self.sessionmaker() try: - result: Result = sess.execute(self._lambda_stmt) + rendered_query = check_and_render_sql_query(self.hass, self._query) + _lambda_stmt = _generate_lambda_stmt(rendered_query) + result: Result = sess.execute(_lambda_stmt) + except TemplateError as err: + _LOGGER.error( + "Error rendering query %s: %s", + redact_credentials(self._query.template), + redact_credentials(str(err)), + ) + sess.rollback() + sess.close() + return except SQLAlchemyError as err: _LOGGER.error( "Error executing query %s: %s", - self._query, + rendered_query, redact_credentials(str(err)), ) sess.rollback() @@ -383,7 +406,7 @@ class SQLSensor(ManualTriggerSensorEntity): return for res in result.mappings(): - _LOGGER.debug("Query %s result in %s", self._query, res.items()) + _LOGGER.debug("Query %s result in %s", rendered_query, res.items()) data = res[self._column_name] for key, value in res.items(): if isinstance(value, decimal.Decimal): @@ -410,6 +433,6 @@ class SQLSensor(ManualTriggerSensorEntity): self._attr_native_value = data if data is None: - _LOGGER.warning("%s returned no results", self._query) + _LOGGER.warning("%s returned no results", rendered_query) sess.close() diff --git a/homeassistant/components/sql/util.py b/homeassistant/components/sql/util.py index 48fb53820ff..39be6ce2fe3 100644 --- a/homeassistant/components/sql/util.py +++ b/homeassistant/components/sql/util.py @@ -4,8 +4,12 @@ from __future__ import annotations import logging +import sqlparse + from homeassistant.components.recorder import get_instance from homeassistant.core import HomeAssistant +from homeassistant.exceptions import TemplateError +from homeassistant.helpers.template import Template from .const import DB_URL_RE @@ -25,3 +29,29 @@ def resolve_db_url(hass: HomeAssistant, db_url: str | None) -> str: if db_url and not db_url.isspace(): return db_url return get_instance(hass).db_url + + +def check_and_render_sql_query(hass: HomeAssistant, query: Template | str) -> str: + """Check and render SQL query.""" + if isinstance(query, str): + query = query.strip() + if not query: + raise ValueError("Query cannot be empty") + query = Template(query, hass=hass) + + try: + query.ensure_valid() + rendered_query: str = query.async_render() + except TemplateError as err: + raise ValueError("Invalid template") from err + if len(rendered_queries := sqlparse.parse(rendered_query.lstrip().lstrip(";"))) > 1: + raise ValueError("Multiple SQL statements are not allowed") + if ( + len(rendered_queries) == 0 + or (query_type := rendered_queries[0].get_type()) == "UNKNOWN" + ): + raise ValueError("SQL query is empty or unknown type") + if query_type != "SELECT": + _LOGGER.debug("The SQL query %s is of type %s", rendered_query, query_type) + raise ValueError("SQL query must be of type SELECT") + return str(rendered_queries[0]) diff --git a/tests/components/sql/test_init.py b/tests/components/sql/test_init.py index 409ebca27c0..531495b7ccf 100644 --- a/tests/components/sql/test_init.py +++ b/tests/components/sql/test_init.py @@ -13,6 +13,7 @@ from homeassistant.components.sql import validate_sql_select from homeassistant.components.sql.const import DOMAIN from homeassistant.config_entries import ConfigEntryState from homeassistant.core import HomeAssistant +from homeassistant.helpers.template import Template from homeassistant.setup import async_setup_component from . import YAML_CONFIG_INVALID, YAML_CONFIG_NO_DB, init_integration @@ -56,34 +57,41 @@ async def test_setup_invalid_config( async def test_invalid_query(hass: HomeAssistant) -> None: """Test invalid query.""" - with pytest.raises(vol.Invalid): - validate_sql_select("DROP TABLE *") + with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"): + validate_sql_select(Template("DROP TABLE *", hass)) - with pytest.raises(vol.Invalid): - validate_sql_select("SELECT5 as value") + with pytest.raises(vol.Invalid, match="SQL query is empty or unknown type"): + validate_sql_select(Template("SELECT5 as value", hass)) - with pytest.raises(vol.Invalid): - validate_sql_select(";;") + with pytest.raises(vol.Invalid, match="SQL query is empty or unknown type"): + validate_sql_select(Template(";;", hass)) async def test_query_no_read_only(hass: HomeAssistant) -> None: """Test query no read only.""" - with pytest.raises(vol.Invalid): - validate_sql_select("UPDATE states SET state = 999999 WHERE state_id = 11125") + with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"): + validate_sql_select( + Template("UPDATE states SET state = 999999 WHERE state_id = 11125", hass) + ) async def test_query_no_read_only_cte(hass: HomeAssistant) -> None: """Test query no read only CTE.""" - with pytest.raises(vol.Invalid): + with pytest.raises(vol.Invalid, match="SQL query must be of type SELECT"): validate_sql_select( - "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;" + Template( + "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;", + hass, + ) ) async def test_multiple_queries(hass: HomeAssistant) -> None: """Test multiple queries.""" - with pytest.raises(vol.Invalid): - validate_sql_select("SELECT 5 as value; UPDATE states SET state = 10;") + with pytest.raises(vol.Invalid, match="Multiple SQL statements are not allowed"): + validate_sql_select( + Template("SELECT 5 as value; UPDATE states SET state = 10;", hass) + ) async def test_remove_configured_db_url_if_not_needed_when_not_needed( diff --git a/tests/components/sql/test_sensor.py b/tests/components/sql/test_sensor.py index 354840c518e..6800fbd58d3 100644 --- a/tests/components/sql/test_sensor.py +++ b/tests/components/sql/test_sensor.py @@ -26,7 +26,6 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant from homeassistant.helpers import issue_registry as ir -from homeassistant.helpers.entity_platform import async_get_platforms from homeassistant.setup import async_setup_component from homeassistant.util import dt as dt_util @@ -624,17 +623,14 @@ async def test_query_recover_from_rollback( "unique_id": "very_unique_id", } await init_integration(hass, config) - platforms = async_get_platforms(hass, "sql") - sql_entity = platforms[0].entities["sensor.select_value_sql_query"] state = hass.states.get("sensor.select_value_sql_query") assert state.state == "5" assert state.attributes["value"] == 5 - with patch.object( - sql_entity, - "_lambda_stmt", - _generate_lambda_stmt("Faulty syntax create operational issue"), + with patch( + "homeassistant.components.sql.sensor._generate_lambda_stmt", + return_value=_generate_lambda_stmt("Faulty syntax create operational issue"), ): freezer.tick(timedelta(minutes=1)) async_fire_time_changed(hass)