Improve sql config flow

This commit is contained in:
G Johansson
2025-08-16 13:52:03 +00:00
parent 7bd126dc8e
commit f7654938ab
5 changed files with 634 additions and 542 deletions

View File

@@ -6,7 +6,7 @@ import logging
from typing import Any
import sqlalchemy
from sqlalchemy.engine import Result
from sqlalchemy.engine import Engine, Result
from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError
from sqlalchemy.orm import Session, scoped_session, sessionmaker
import sqlparse
@@ -32,9 +32,10 @@ from homeassistant.const import (
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import callback
from homeassistant.data_entry_flow import section
from homeassistant.helpers import selector
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .util import resolve_db_url
_LOGGER = logging.getLogger(__name__)
@@ -42,40 +43,38 @@ _LOGGER = logging.getLogger(__name__)
OPTIONS_SCHEMA: vol.Schema = vol.Schema(
{
vol.Optional(
CONF_DB_URL,
): selector.TextSelector(),
vol.Required(
CONF_COLUMN_NAME,
): selector.TextSelector(),
vol.Required(
CONF_QUERY,
): selector.TextSelector(selector.TextSelectorConfig(multiline=True)),
vol.Optional(
CONF_UNIT_OF_MEASUREMENT,
): selector.TextSelector(),
vol.Optional(
CONF_VALUE_TEMPLATE,
): selector.TemplateSelector(),
vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector(
selector.SelectSelectorConfig(
options=[
cls.value
for cls in SensorDeviceClass
if cls != SensorDeviceClass.ENUM
],
mode=selector.SelectSelectorMode.DROPDOWN,
translation_key="device_class",
sort=True,
)
vol.Required(CONF_COLUMN_NAME): selector.TextSelector(),
vol.Required(CONF_QUERY): selector.TextSelector(
selector.TextSelectorConfig(multiline=True)
),
vol.Optional(CONF_STATE_CLASS): selector.SelectSelector(
selector.SelectSelectorConfig(
options=[cls.value for cls in SensorStateClass],
mode=selector.SelectSelectorMode.DROPDOWN,
translation_key="state_class",
sort=True,
)
vol.Required(CONF_ADVANCED_OPTIONS): section(
vol.Schema(
{
vol.Optional(CONF_VALUE_TEMPLATE): selector.TemplateSelector(),
vol.Optional(CONF_UNIT_OF_MEASUREMENT): selector.TextSelector(),
vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector(
selector.SelectSelectorConfig(
options=[
cls.value
for cls in SensorDeviceClass
if cls != SensorDeviceClass.ENUM
],
mode=selector.SelectSelectorMode.DROPDOWN,
translation_key="device_class",
sort=True,
)
),
vol.Optional(CONF_STATE_CLASS): selector.SelectSelector(
selector.SelectSelectorConfig(
options=[cls.value for cls in SensorStateClass],
mode=selector.SelectSelectorMode.DROPDOWN,
translation_key="state_class",
sort=True,
)
),
}
),
{"collapsed": True},
),
}
)
@@ -83,8 +82,9 @@ OPTIONS_SCHEMA: vol.Schema = vol.Schema(
CONFIG_SCHEMA: vol.Schema = vol.Schema(
{
vol.Required(CONF_NAME, default="Select SQL Query"): selector.TextSelector(),
vol.Optional(CONF_DB_URL): selector.TextSelector(),
}
).extend(OPTIONS_SCHEMA.schema)
)
def validate_sql_select(value: str) -> str:
@@ -99,6 +99,31 @@ def validate_sql_select(value: str) -> str:
return str(query[0])
def validate_db_connection(db_url: str) -> bool:
"""Validate db connection."""
engine: Engine | None = None
sess: Session | None = None
try:
engine = sqlalchemy.create_engine(db_url, future=True)
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
sess = sessmaker()
sess.execute(sqlalchemy.text("select 1 as value"))
except SQLAlchemyError as error:
_LOGGER.debug("Execution error %s", error)
if sess:
sess.close()
if engine:
engine.dispose()
raise
if sess:
sess.close()
engine.dispose()
return True
def validate_query(db_url: str, query: str, column: str) -> bool:
"""Validate SQL query."""
@@ -138,6 +163,8 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1
data: dict[str, Any]
@staticmethod
@callback
def async_get_options_flow(
@@ -151,17 +178,46 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
"""Handle the user step."""
errors = {}
description_placeholders = {}
if user_input is not None:
db_url = user_input.get(CONF_DB_URL)
try:
db_url_for_validation = resolve_db_url(self.hass, db_url)
await self.hass.async_add_executor_job(
validate_db_connection, db_url_for_validation
)
except SQLAlchemyError:
errors["db_url"] = "db_url_invalid"
if not errors:
self.data = {CONF_NAME: user_input[CONF_NAME]}
if db_url and db_url_for_validation != get_instance(self.hass).db_url:
self.data[CONF_DB_URL] = db_url
return await self.async_step_options()
return self.async_show_form(
step_id="user",
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
errors=errors,
)
async def async_step_options(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the user step."""
errors = {}
description_placeholders = {}
if user_input is not None:
query = user_input[CONF_QUERY]
column = user_input[CONF_COLUMN_NAME]
db_url_for_validation = None
try:
query = validate_sql_select(query)
db_url_for_validation = resolve_db_url(self.hass, db_url)
db_url_for_validation = resolve_db_url(
self.hass, self.data.get(CONF_DB_URL)
)
await self.hass.async_add_executor_job(
validate_query, db_url_for_validation, query, column
)
@@ -178,32 +234,22 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
_LOGGER.debug("Invalid query: %s", err)
errors["query"] = "query_invalid"
options = {
CONF_QUERY: query,
CONF_COLUMN_NAME: column,
CONF_NAME: user_input[CONF_NAME],
}
if uom := user_input.get(CONF_UNIT_OF_MEASUREMENT):
options[CONF_UNIT_OF_MEASUREMENT] = uom
if value_template := user_input.get(CONF_VALUE_TEMPLATE):
options[CONF_VALUE_TEMPLATE] = value_template
if device_class := user_input.get(CONF_DEVICE_CLASS):
options[CONF_DEVICE_CLASS] = device_class
if state_class := user_input.get(CONF_STATE_CLASS):
options[CONF_STATE_CLASS] = state_class
if db_url_for_validation != get_instance(self.hass).db_url:
options[CONF_DB_URL] = db_url_for_validation
for k, v in user_input[CONF_ADVANCED_OPTIONS].items():
if not v:
user_input[CONF_ADVANCED_OPTIONS].pop(k)
if not errors:
name = self.data[CONF_NAME]
self.data.pop(CONF_NAME)
return self.async_create_entry(
title=user_input[CONF_NAME],
data={},
options=options,
title=name,
data=self.data,
options=user_input,
)
return self.async_show_form(
step_id="user",
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
step_id="options",
data_schema=self.add_suggested_values_to_schema(OPTIONS_SCHEMA, user_input),
errors=errors,
description_placeholders=description_placeholders,
)
@@ -220,10 +266,9 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload):
description_placeholders = {}
if user_input is not None:
db_url = user_input.get(CONF_DB_URL)
db_url = self.config_entry.data.get(CONF_DB_URL)
query = user_input[CONF_QUERY]
column = user_input[CONF_COLUMN_NAME]
name = self.config_entry.options.get(CONF_NAME, self.config_entry.title)
try:
query = validate_sql_select(query)
@@ -252,24 +297,12 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload):
recorder_db,
)
options = {
CONF_QUERY: query,
CONF_COLUMN_NAME: column,
CONF_NAME: name,
}
if uom := user_input.get(CONF_UNIT_OF_MEASUREMENT):
options[CONF_UNIT_OF_MEASUREMENT] = uom
if value_template := user_input.get(CONF_VALUE_TEMPLATE):
options[CONF_VALUE_TEMPLATE] = value_template
if device_class := user_input.get(CONF_DEVICE_CLASS):
options[CONF_DEVICE_CLASS] = device_class
if state_class := user_input.get(CONF_STATE_CLASS):
options[CONF_STATE_CLASS] = state_class
if db_url_for_validation != get_instance(self.hass).db_url:
options[CONF_DB_URL] = db_url_for_validation
for k, v in user_input[CONF_ADVANCED_OPTIONS].items():
if not v:
user_input[CONF_ADVANCED_OPTIONS].pop(k)
return self.async_create_entry(
data=options,
data=user_input,
)
return self.async_show_form(

View File

@@ -9,4 +9,5 @@ PLATFORMS = [Platform.SENSOR]
CONF_COLUMN_NAME = "column"
CONF_QUERY = "query"
CONF_ADVANCED_OPTIONS = "advanced_options"
DB_URL_RE = re.compile("//.*:.*@")

View File

@@ -10,7 +10,12 @@ from homeassistant.components.sensor import (
SensorDeviceClass,
SensorStateClass,
)
from homeassistant.components.sql.const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from homeassistant.components.sql.const import (
CONF_ADVANCED_OPTIONS,
CONF_COLUMN_NAME,
CONF_QUERY,
DOMAIN,
)
from homeassistant.config_entries import SOURCE_USER
from homeassistant.const import (
CONF_DEVICE_CLASS,
@@ -30,140 +35,167 @@ from homeassistant.helpers.trigger_template_entity import (
from tests.common import MockConfigEntry
ENTRY_CONFIG = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
}
ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_VALUE_TEMPLATE: "{{ value }}",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_VALUE_TEMPLATE: "{{ value }}",
},
}
ENTRY_CONFIG_INVALID_QUERY = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_INVALID_QUERY_2 = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_INVALID_QUERY_3 = {
CONF_NAME: "Get Value",
CONF_QUERY: ";;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_INVALID_QUERY_OPT = {
CONF_QUERY: "SELECT 5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_INVALID_QUERY_2_OPT = {
CONF_QUERY: "SELECT5 FROM as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_INVALID_QUERY_3_OPT = {
CONF_QUERY: ";;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_QUERY_READ_ONLY_CTE = {
CONF_NAME: "Get Value",
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY = {
CONF_NAME: "Get Value",
CONF_QUERY: "UPDATE states SET state = 999999 WHERE state_id = 11125",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE = {
CONF_NAME: "Get Value",
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_QUERY_READ_ONLY_CTE_OPT = {
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT = {
CONF_QUERY: "UPDATE 5 as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT = {
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_MULTIPLE_QUERIES = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_MULTIPLE_QUERIES_OPT = {
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_INVALID_COLUMN_NAME = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_INVALID_COLUMN_NAME_OPT = {
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "size",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
ENTRY_CONFIG_NO_RESULTS = {
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT kalle as value from no_table;",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_ADVANCED_OPTIONS: {
CONF_UNIT_OF_MEASUREMENT: "MiB",
},
}
YAML_CONFIG = {

View File

@@ -0,0 +1,41 @@
"""Fixtures for the SQL integration."""
from __future__ import annotations
from collections.abc import Generator
from pathlib import Path
import sqlite3
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.core import HomeAssistant
@pytest.fixture
def mock_setup_entry() -> Generator[AsyncMock]:
"""Override async_setup_entry."""
with patch(
"homeassistant.components.sql.async_setup_entry", return_value=True
) as mock_setup_entry:
yield mock_setup_entry
@pytest.fixture
async def create_db(
hass: HomeAssistant,
tmp_path: Path,
) -> str:
"""Test the SQL sensor with a query that returns no value."""
db_path = tmp_path / "test.db"
db_path_str = f"sqlite:///{db_path}"
def make_test_db():
"""Create a test database."""
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE users (value INTEGER)")
conn.commit()
conn.close()
await hass.async_add_executor_job(make_test_db)
return db_path_str

File diff suppressed because it is too large Load Diff