Add preview to sql config flow

This commit is contained in:
G Johansson
2025-08-09 15:26:32 +00:00
parent f8d3bc1b89
commit c3773a79ce
4 changed files with 525 additions and 17 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
import logging import logging
from typing import Any from typing import Any
@@ -13,9 +14,11 @@ import sqlparse
from sqlparse.exceptions import SQLParseError from sqlparse.exceptions import SQLParseError
import voluptuous as vol import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.recorder import CONF_DB_URL, get_instance from homeassistant.components.recorder import CONF_DB_URL, get_instance
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
CONF_STATE_CLASS, CONF_STATE_CLASS,
DOMAIN as SENSOR_DOMAIN,
SensorDeviceClass, SensorDeviceClass,
SensorStateClass, SensorStateClass,
) )
@@ -31,10 +34,15 @@ from homeassistant.const import (
CONF_UNIT_OF_MEASUREMENT, CONF_UNIT_OF_MEASUREMENT,
CONF_VALUE_TEMPLATE, CONF_VALUE_TEMPLATE,
) )
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import selector from homeassistant.helpers import selector
from homeassistant.helpers.entity_platform import PlatformData
from homeassistant.helpers.template import Template
from homeassistant.helpers.trigger_template_entity import ValueTemplate
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .sensor import TRIGGER_ENTITY_OPTIONS, SQLSensor, get_db_connection
from .util import resolve_db_url from .util import resolve_db_url
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -138,6 +146,11 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
@staticmethod
async def async_setup_preview(hass: HomeAssistant) -> None:
"""Set up preview WS API."""
websocket_api.async_register_command(hass, ws_start_preview)
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow( def async_get_options_flow(
@@ -206,6 +219,7 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input), data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
errors=errors, errors=errors,
description_placeholders=description_placeholders, description_placeholders=description_placeholders,
preview="sql",
) )
@@ -279,4 +293,103 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload):
), ),
errors=errors, errors=errors,
description_placeholders=description_placeholders, description_placeholders=description_placeholders,
preview="sql",
)
@staticmethod
async def async_setup_preview(hass: HomeAssistant) -> None:
"""Set up preview WS API."""
websocket_api.async_register_command(hass, ws_start_preview)
@websocket_api.websocket_command(
{
vol.Required("type"): "sql/start_preview",
vol.Required("flow_id"): str,
vol.Required("flow_type"): vol.Any("config_flow", "options_flow"),
vol.Required("user_input"): dict,
}
)
@websocket_api.async_response
async def ws_start_preview(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Generate a preview."""
if msg["flow_type"] == "config_flow":
flow_status = hass.config_entries.flow.async_get(msg["flow_id"])
flow_sets = hass.config_entries.flow._handler_progress_index.get( # noqa: SLF001
flow_status["handler"]
)
assert flow_sets
config_entry = hass.config_entries.async_get_entry(flow_status["handler"])
name = msg["user_input"][CONF_NAME]
else:
flow_status = hass.config_entries.options.async_get(msg["flow_id"])
config_entry = hass.config_entries.async_get_entry(flow_status["handler"])
if not config_entry:
raise HomeAssistantError("Config entry not found")
name = config_entry.options[CONF_NAME]
@callback
def async_preview_updated(state: str, attributes: Mapping[str, Any]) -> None:
"""Forward config entry state events to websocket."""
connection.send_message(
websocket_api.event_message(
msg["id"], {"attributes": attributes, "state": state}
)
)
db_url = resolve_db_url(hass, msg["user_input"].get(CONF_DB_URL))
if (
db_connection := await get_db_connection(
hass,
db_url,
)
) is None:
return # Missing test
sessmaker = db_connection[0]
use_database_executor = db_connection[1]
name_template = Template(name, hass)
trigger_entity_config = {CONF_NAME: name_template}
for key in TRIGGER_ENTITY_OPTIONS:
if key not in msg["user_input"]:
continue
trigger_entity_config[key] = msg["user_input"][key]
query_str: str = msg["user_input"].get(CONF_QUERY)
template: str | None = msg["user_input"].get(CONF_VALUE_TEMPLATE)
column_name: str = msg["user_input"].get(CONF_COLUMN_NAME)
value_template: ValueTemplate | None = None
if template is not None:
try:
value_template = ValueTemplate(template, hass)
value_template.ensure_valid()
except TemplateError:
value_template = None
preview_entity = SQLSensor(
trigger_entity_config=trigger_entity_config,
sessmaker=sessmaker,
query=query_str,
column=column_name,
value_template=value_template,
yaml=False,
use_database_executor=use_database_executor,
)
preview_entity.hass = hass
# Create PlatformData, needed for name translations
platform_data = PlatformData(hass=hass, domain=SENSOR_DOMAIN, platform_name=DOMAIN)
await platform_data.async_load_translations()
connection.send_result(msg["id"])
connection.subscriptions[msg["id"]] = await preview_entity.async_start_preview(
async_preview_updated
) )

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable, Mapping
from datetime import date from datetime import date
import decimal import decimal
import logging import logging
@@ -10,7 +11,7 @@ from typing import Any
import sqlalchemy import sqlalchemy
from sqlalchemy import lambda_stmt from sqlalchemy import lambda_stmt
from sqlalchemy.engine import Result from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import NoSuchColumnError, SQLAlchemyError
from sqlalchemy.orm import Session, scoped_session, sessionmaker from sqlalchemy.orm import Session, scoped_session, sessionmaker
from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.util import LRUCache from sqlalchemy.util import LRUCache
@@ -32,7 +33,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
MATCH_ALL, MATCH_ALL,
) )
from homeassistant.core import Event, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.exceptions import TemplateError from homeassistant.exceptions import TemplateError
from homeassistant.helpers import issue_registry as ir from homeassistant.helpers import issue_registry as ir
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
@@ -175,18 +176,11 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData:
return sql_data return sql_data
async def async_setup_sensor( async def get_db_connection(
hass: HomeAssistant, hass: HomeAssistant,
trigger_entity_config: ConfigType,
query_str: str,
column_name: str,
value_template: ValueTemplate | None,
unique_id: str | None,
db_url: str, db_url: str,
yaml: bool, ) -> tuple[scoped_session, bool, bool] | None:
async_add_entities: AddEntitiesCallback | AddConfigEntryEntitiesCallback, """Get a database connection."""
) -> None:
"""Set up the SQL sensor."""
try: try:
instance = get_instance(hass) instance = get_instance(hass)
except KeyError: # No recorder loaded except KeyError: # No recorder loaded
@@ -213,7 +207,33 @@ async def async_setup_sensor(
): ):
sql_data.session_makers_by_db_url[db_url] = sessmaker sql_data.session_makers_by_db_url[db_url] = sessmaker
else: else:
return None
return sessmaker, use_database_executor, uses_recorder_db
async def async_setup_sensor(
hass: HomeAssistant,
trigger_entity_config: ConfigType,
query_str: str,
column_name: str,
value_template: ValueTemplate | None,
unique_id: str | None,
db_url: str,
yaml: bool,
async_add_entities: AddEntitiesCallback | AddConfigEntryEntitiesCallback,
) -> None:
"""Set up the SQL sensor."""
if (
db_connection := await get_db_connection(
hass,
db_url,
)
) is None:
return return
sessmaker = db_connection[0]
use_database_executor = db_connection[1]
uses_recorder_db = db_connection[2]
upper_query = query_str.upper() upper_query = query_str.upper()
if uses_recorder_db: if uses_recorder_db:
@@ -339,6 +359,7 @@ class SQLSensor(ManualTriggerSensorEntity):
manufacturer="SQL", manufacturer="SQL",
name=self._rendered.get(CONF_NAME), name=self._rendered.get(CONF_NAME),
) )
self._preview_callback: Callable[[str, Mapping[str, Any]], None] | None = None
@property @property
def name(self) -> str | None: def name(self) -> str | None:
@@ -357,12 +378,32 @@ class SQLSensor(ManualTriggerSensorEntity):
"""Return extra attributes.""" """Return extra attributes."""
return dict(self._attr_extra_state_attributes) return dict(self._attr_extra_state_attributes)
async def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
# abort early if there is needed data missing
if not self._query or not self._column_name:
self._attr_available = False
calculated_state = self._async_calculate_state()
preview_callback(calculated_state.state, calculated_state.attributes)
return self._call_on_remove_callbacks
self._preview_callback = preview_callback
await self.async_update()
return self._call_on_remove_callbacks
async def async_update(self) -> None: async def async_update(self) -> None:
"""Retrieve sensor data from the query using the right executor.""" """Retrieve sensor data from the query using the right executor."""
if self._use_database_executor: if self._use_database_executor:
await get_instance(self.hass).async_add_executor_job(self._update) await get_instance(self.hass).async_add_executor_job(self._update)
else: else:
await self.hass.async_add_executor_job(self._update) await self.hass.async_add_executor_job(self._update)
if self._preview_callback:
calculated_state = self._async_calculate_state()
self._preview_callback(calculated_state.state, calculated_state.attributes)
def _update(self) -> None: def _update(self) -> None:
"""Retrieve sensor data from the query.""" """Retrieve sensor data from the query."""
@@ -384,7 +425,19 @@ class SQLSensor(ManualTriggerSensorEntity):
for res in result.mappings(): for res in result.mappings():
_LOGGER.debug("Query %s result in %s", self._query, res.items()) _LOGGER.debug("Query %s result in %s", self._query, res.items())
try:
data = res[self._column_name] data = res[self._column_name]
except NoSuchColumnError as err:
_LOGGER.error(
"Column %s not found in query result for query %s: %s",
self._column_name,
self._query,
redact_credentials(str(err)),
)
sess.rollback()
sess.close()
return
for key, value in res.items(): for key, value in res.items():
if isinstance(value, decimal.Decimal): if isinstance(value, decimal.Decimal):
value = float(value) value = float(value)
@@ -398,6 +451,8 @@ class SQLSensor(ManualTriggerSensorEntity):
if data is not None and isinstance(data, (bytes, bytearray)): if data is not None and isinstance(data, (bytes, bytearray)):
data = f"0x{data.hex()}" data = f"0x{data.hex()}"
print(data, self._template)
if data is not None and self._template is not None: if data is not None and self._template is not None:
variables = self._template_variables_with_value(data) variables = self._template_variables_with_value(data)
if self._render_availability_template(variables): if self._render_availability_template(variables):
@@ -406,6 +461,7 @@ class SQLSensor(ManualTriggerSensorEntity):
) )
self._set_native_value_with_possible_timestamp(_value) self._set_native_value_with_possible_timestamp(_value)
self._process_manual_data(variables) self._process_manual_data(variables)
print(self._attr_native_value)
else: else:
self._attr_native_value = data self._attr_native_value = data

View File

@@ -0,0 +1,80 @@
# serializer version: 1
# name: test_config_flow_preview[incorrect_column]
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
}),
'state': 'unknown',
})
# ---
# name: test_config_flow_preview[missing_column]
dict({
'attributes': dict({
'friendly_name': 'Get Value',
}),
'state': 'unknown',
})
# ---
# name: test_config_flow_preview[success]
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 5,
}),
'state': '5',
})
# ---
# name: test_config_flow_preview[with_value_template]
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 5,
}),
'state': '5',
})
# ---
# name: test_config_flow_preview[with_value_template_invalid]
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 5,
}),
'state': '5',
})
# ---
# name: test_config_flow_preview_no_database
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 5,
}),
'state': '5',
})
# ---
# name: test_options_flow_preview
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 6,
}),
'state': '6',
})
# ---

View File

@@ -5,12 +5,24 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from syrupy.assertion import SnapshotAssertion
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.recorder import Recorder from homeassistant.components.recorder import CONF_DB_URL, Recorder
from homeassistant.components.sensor import SensorDeviceClass, SensorStateClass from homeassistant.components.sensor import (
from homeassistant.components.sql.const import DOMAIN CONF_STATE_CLASS,
SensorDeviceClass,
SensorStateClass,
)
from homeassistant.components.sql.const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from homeassistant.const import (
CONF_DEVICE_CLASS,
CONF_NAME,
CONF_UNIT_OF_MEASUREMENT,
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@@ -35,6 +47,7 @@ from . import (
) )
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None: async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None:
@@ -795,3 +808,249 @@ async def test_device_state_class(recorder_mock: Recorder, hass: HomeAssistant)
"column": "value", "column": "value",
"unit_of_measurement": "MiB", "unit_of_measurement": "MiB",
} }
@pytest.mark.parametrize(
"user_input",
[
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
}
),
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
}
),
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
}
),
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_VALUE_TEMPLATE: "{{ value }}",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
}
),
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_VALUE_TEMPLATE: "{{ value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
}
),
],
ids=(
"success",
"incorrect_column",
"missing_column",
"with_value_template",
"with_value_template_invalid",
),
)
async def test_config_flow_preview(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
user_input: str,
snapshot: SnapshotAssertion,
) -> None:
"""Test the config flow preview."""
client = await hass_ws_client(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user"
assert result["errors"] == {}
assert result["preview"] == "sql"
await client.send_json_auto_id(
{
"type": "sql/start_preview",
"flow_id": result["flow_id"],
"flow_type": "config_flow",
"user_input": user_input,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] is None
msg = await client.receive_json()
assert msg["event"] == snapshot
assert len(hass.states.async_all()) == 0
async def test_config_flow_preview_no_database(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test the config flow preview with no database."""
client = await hass_ws_client(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user"
assert result["errors"] == {}
assert result["preview"] == "sql"
await client.send_json_auto_id(
{
"type": "sql/start_preview",
"flow_id": result["flow_id"],
"flow_type": "config_flow",
"user_input": {
CONF_DB_URL: "sqlite://homeassistant.local",
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
}
)
# msg = await client.receive_json()
# assert msg["success"]
# assert msg["result"] is None
# msg = await client.receive_json()
# assert msg["event"] == snapshot
# assert len(hass.states.async_all()) == 0
await hass.async_block_till_done(wait_background_tasks=True)
assert False
async def test_options_flow_preview(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test the options flow preview."""
client = await hass_ws_client(hass)
# Setup the config entry
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
title="Get Value",
)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["preview"] == "sql"
await client.send_json_auto_id(
{
"type": "sql/start_preview",
"flow_id": result["flow_id"],
"flow_type": "options_flow",
"user_input": {
CONF_QUERY: "SELECT 6 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] is None
msg = await client.receive_json()
assert msg["event"] == snapshot
assert len(hass.states.async_all()) == 1
async def test_options_flow_sensor_preview_config_entry_removed(
recorder_mock: Recorder, hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test the option flow preview where the config entry is removed."""
client = await hass_ws_client(hass)
# Setup the config entry
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
title="Get Value",
)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["preview"] == "sql"
await hass.config_entries.async_remove(config_entry.entry_id)
await client.send_json_auto_id(
{
"type": "sql/start_preview",
"flow_id": result["flow_id"],
"flow_type": "options_flow",
"user_input": {
CONF_QUERY: "SELECT 6 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Config entry not found",
}