Add preview to sql config flow

This commit is contained in:
G Johansson
2025-08-09 15:26:32 +00:00
parent f8d3bc1b89
commit c3773a79ce
4 changed files with 525 additions and 17 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Mapping
import logging
from typing import Any
@@ -13,9 +14,11 @@ import sqlparse
from sqlparse.exceptions import SQLParseError
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.recorder import CONF_DB_URL, get_instance
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
DOMAIN as SENSOR_DOMAIN,
SensorDeviceClass,
SensorStateClass,
)
@@ -31,10 +34,15 @@ from homeassistant.const import (
CONF_UNIT_OF_MEASUREMENT,
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import callback
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import selector
from homeassistant.helpers.entity_platform import PlatformData
from homeassistant.helpers.template import Template
from homeassistant.helpers.trigger_template_entity import ValueTemplate
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from .sensor import TRIGGER_ENTITY_OPTIONS, SQLSensor, get_db_connection
from .util import resolve_db_url
_LOGGER = logging.getLogger(__name__)
@@ -138,6 +146,11 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1
@staticmethod
async def async_setup_preview(hass: HomeAssistant) -> None:
"""Set up preview WS API."""
websocket_api.async_register_command(hass, ws_start_preview)
@staticmethod
@callback
def async_get_options_flow(
@@ -206,6 +219,7 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
errors=errors,
description_placeholders=description_placeholders,
preview="sql",
)
@@ -279,4 +293,103 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload):
),
errors=errors,
description_placeholders=description_placeholders,
preview="sql",
)
@staticmethod
async def async_setup_preview(hass: HomeAssistant) -> None:
"""Set up preview WS API."""
websocket_api.async_register_command(hass, ws_start_preview)
@websocket_api.websocket_command(
{
vol.Required("type"): "sql/start_preview",
vol.Required("flow_id"): str,
vol.Required("flow_type"): vol.Any("config_flow", "options_flow"),
vol.Required("user_input"): dict,
}
)
@websocket_api.async_response
async def ws_start_preview(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Generate a preview."""
if msg["flow_type"] == "config_flow":
flow_status = hass.config_entries.flow.async_get(msg["flow_id"])
flow_sets = hass.config_entries.flow._handler_progress_index.get( # noqa: SLF001
flow_status["handler"]
)
assert flow_sets
config_entry = hass.config_entries.async_get_entry(flow_status["handler"])
name = msg["user_input"][CONF_NAME]
else:
flow_status = hass.config_entries.options.async_get(msg["flow_id"])
config_entry = hass.config_entries.async_get_entry(flow_status["handler"])
if not config_entry:
raise HomeAssistantError("Config entry not found")
name = config_entry.options[CONF_NAME]
@callback
def async_preview_updated(state: str, attributes: Mapping[str, Any]) -> None:
"""Forward config entry state events to websocket."""
connection.send_message(
websocket_api.event_message(
msg["id"], {"attributes": attributes, "state": state}
)
)
db_url = resolve_db_url(hass, msg["user_input"].get(CONF_DB_URL))
if (
db_connection := await get_db_connection(
hass,
db_url,
)
) is None:
return # Missing test
sessmaker = db_connection[0]
use_database_executor = db_connection[1]
name_template = Template(name, hass)
trigger_entity_config = {CONF_NAME: name_template}
for key in TRIGGER_ENTITY_OPTIONS:
if key not in msg["user_input"]:
continue
trigger_entity_config[key] = msg["user_input"][key]
query_str: str = msg["user_input"].get(CONF_QUERY)
template: str | None = msg["user_input"].get(CONF_VALUE_TEMPLATE)
column_name: str = msg["user_input"].get(CONF_COLUMN_NAME)
value_template: ValueTemplate | None = None
if template is not None:
try:
value_template = ValueTemplate(template, hass)
value_template.ensure_valid()
except TemplateError:
value_template = None
preview_entity = SQLSensor(
trigger_entity_config=trigger_entity_config,
sessmaker=sessmaker,
query=query_str,
column=column_name,
value_template=value_template,
yaml=False,
use_database_executor=use_database_executor,
)
preview_entity.hass = hass
# Create PlatformData, needed for name translations
platform_data = PlatformData(hass=hass, domain=SENSOR_DOMAIN, platform_name=DOMAIN)
await platform_data.async_load_translations()
connection.send_result(msg["id"])
connection.subscriptions[msg["id"]] = await preview_entity.async_start_preview(
async_preview_updated
)

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Callable, Mapping
from datetime import date
import decimal
import logging
@@ -10,7 +11,7 @@ from typing import Any
import sqlalchemy
from sqlalchemy import lambda_stmt
from sqlalchemy.engine import Result
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.exc import NoSuchColumnError, SQLAlchemyError
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.util import LRUCache
@@ -32,7 +33,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
MATCH_ALL,
)
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import issue_registry as ir
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
@@ -175,18 +176,11 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData:
return sql_data
async def async_setup_sensor(
async def get_db_connection(
hass: HomeAssistant,
trigger_entity_config: ConfigType,
query_str: str,
column_name: str,
value_template: ValueTemplate | None,
unique_id: str | None,
db_url: str,
yaml: bool,
async_add_entities: AddEntitiesCallback | AddConfigEntryEntitiesCallback,
) -> None:
"""Set up the SQL sensor."""
) -> tuple[scoped_session, bool, bool] | None:
"""Get a database connection."""
try:
instance = get_instance(hass)
except KeyError: # No recorder loaded
@@ -213,7 +207,33 @@ async def async_setup_sensor(
):
sql_data.session_makers_by_db_url[db_url] = sessmaker
else:
return None
return sessmaker, use_database_executor, uses_recorder_db
async def async_setup_sensor(
hass: HomeAssistant,
trigger_entity_config: ConfigType,
query_str: str,
column_name: str,
value_template: ValueTemplate | None,
unique_id: str | None,
db_url: str,
yaml: bool,
async_add_entities: AddEntitiesCallback | AddConfigEntryEntitiesCallback,
) -> None:
"""Set up the SQL sensor."""
if (
db_connection := await get_db_connection(
hass,
db_url,
)
) is None:
return
sessmaker = db_connection[0]
use_database_executor = db_connection[1]
uses_recorder_db = db_connection[2]
upper_query = query_str.upper()
if uses_recorder_db:
@@ -339,6 +359,7 @@ class SQLSensor(ManualTriggerSensorEntity):
manufacturer="SQL",
name=self._rendered.get(CONF_NAME),
)
self._preview_callback: Callable[[str, Mapping[str, Any]], None] | None = None
@property
def name(self) -> str | None:
@@ -357,12 +378,32 @@ class SQLSensor(ManualTriggerSensorEntity):
"""Return extra attributes."""
return dict(self._attr_extra_state_attributes)
async def async_start_preview(
self,
preview_callback: Callable[[str, Mapping[str, Any]], None],
) -> CALLBACK_TYPE:
"""Render a preview."""
# abort early if there is needed data missing
if not self._query or not self._column_name:
self._attr_available = False
calculated_state = self._async_calculate_state()
preview_callback(calculated_state.state, calculated_state.attributes)
return self._call_on_remove_callbacks
self._preview_callback = preview_callback
await self.async_update()
return self._call_on_remove_callbacks
async def async_update(self) -> None:
"""Retrieve sensor data from the query using the right executor."""
if self._use_database_executor:
await get_instance(self.hass).async_add_executor_job(self._update)
else:
await self.hass.async_add_executor_job(self._update)
if self._preview_callback:
calculated_state = self._async_calculate_state()
self._preview_callback(calculated_state.state, calculated_state.attributes)
def _update(self) -> None:
"""Retrieve sensor data from the query."""
@@ -384,7 +425,19 @@ class SQLSensor(ManualTriggerSensorEntity):
for res in result.mappings():
_LOGGER.debug("Query %s result in %s", self._query, res.items())
try:
data = res[self._column_name]
except NoSuchColumnError as err:
_LOGGER.error(
"Column %s not found in query result for query %s: %s",
self._column_name,
self._query,
redact_credentials(str(err)),
)
sess.rollback()
sess.close()
return
for key, value in res.items():
if isinstance(value, decimal.Decimal):
value = float(value)
@@ -398,6 +451,8 @@ class SQLSensor(ManualTriggerSensorEntity):
if data is not None and isinstance(data, (bytes, bytearray)):
data = f"0x{data.hex()}"
print(data, self._template)
if data is not None and self._template is not None:
variables = self._template_variables_with_value(data)
if self._render_availability_template(variables):
@@ -406,6 +461,7 @@ class SQLSensor(ManualTriggerSensorEntity):
)
self._set_native_value_with_possible_timestamp(_value)
self._process_manual_data(variables)
print(self._attr_native_value)
else:
self._attr_native_value = data

View File

@@ -0,0 +1,80 @@
# serializer version: 1
# name: test_config_flow_preview[incorrect_column]
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
}),
'state': 'unknown',
})
# ---
# name: test_config_flow_preview[missing_column]
dict({
'attributes': dict({
'friendly_name': 'Get Value',
}),
'state': 'unknown',
})
# ---
# name: test_config_flow_preview[success]
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 5,
}),
'state': '5',
})
# ---
# name: test_config_flow_preview[with_value_template]
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 5,
}),
'state': '5',
})
# ---
# name: test_config_flow_preview[with_value_template_invalid]
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 5,
}),
'state': '5',
})
# ---
# name: test_config_flow_preview_no_database
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 5,
}),
'state': '5',
})
# ---
# name: test_options_flow_preview
dict({
'attributes': dict({
'device_class': 'data_size',
'friendly_name': 'Get Value',
'state_class': 'total',
'unit_of_measurement': 'MiB',
'value': 6,
}),
'state': '6',
})
# ---

View File

@@ -5,12 +5,24 @@ from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy.exc import SQLAlchemyError
from syrupy.assertion import SnapshotAssertion
from homeassistant import config_entries
from homeassistant.components.recorder import Recorder
from homeassistant.components.sensor import SensorDeviceClass, SensorStateClass
from homeassistant.components.sql.const import DOMAIN
from homeassistant.components.recorder import CONF_DB_URL, Recorder
from homeassistant.components.sensor import (
CONF_STATE_CLASS,
SensorDeviceClass,
SensorStateClass,
)
from homeassistant.components.sql.const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
from homeassistant.const import (
CONF_DEVICE_CLASS,
CONF_NAME,
CONF_UNIT_OF_MEASUREMENT,
CONF_VALUE_TEMPLATE,
)
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
@@ -35,6 +47,7 @@ from . import (
)
from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator
async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None:
@@ -795,3 +808,249 @@ async def test_device_state_class(recorder_mock: Recorder, hass: HomeAssistant)
"column": "value",
"unit_of_measurement": "MiB",
}
@pytest.mark.parametrize(
"user_input",
[
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
}
),
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "state",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
}
),
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
}
),
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_VALUE_TEMPLATE: "{{ value }}",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
}
),
(
{
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_VALUE_TEMPLATE: "{{ value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
}
),
],
ids=(
"success",
"incorrect_column",
"missing_column",
"with_value_template",
"with_value_template_invalid",
),
)
async def test_config_flow_preview(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
user_input: str,
snapshot: SnapshotAssertion,
) -> None:
"""Test the config flow preview."""
client = await hass_ws_client(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user"
assert result["errors"] == {}
assert result["preview"] == "sql"
await client.send_json_auto_id(
{
"type": "sql/start_preview",
"flow_id": result["flow_id"],
"flow_type": "config_flow",
"user_input": user_input,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] is None
msg = await client.receive_json()
assert msg["event"] == snapshot
assert len(hass.states.async_all()) == 0
async def test_config_flow_preview_no_database(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test the config flow preview with no database."""
client = await hass_ws_client(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user"
assert result["errors"] == {}
assert result["preview"] == "sql"
await client.send_json_auto_id(
{
"type": "sql/start_preview",
"flow_id": result["flow_id"],
"flow_type": "config_flow",
"user_input": {
CONF_DB_URL: "sqlite://homeassistant.local",
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
}
)
# msg = await client.receive_json()
# assert msg["success"]
# assert msg["result"] is None
# msg = await client.receive_json()
# assert msg["event"] == snapshot
# assert len(hass.states.async_all()) == 0
await hass.async_block_till_done(wait_background_tasks=True)
assert False
async def test_options_flow_preview(
recorder_mock: Recorder,
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
) -> None:
"""Test the options flow preview."""
client = await hass_ws_client(hass)
# Setup the config entry
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
title="Get Value",
)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["preview"] == "sql"
await client.send_json_auto_id(
{
"type": "sql/start_preview",
"flow_id": result["flow_id"],
"flow_type": "options_flow",
"user_input": {
CONF_QUERY: "SELECT 6 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] is None
msg = await client.receive_json()
assert msg["event"] == snapshot
assert len(hass.states.async_all()) == 1
async def test_options_flow_sensor_preview_config_entry_removed(
recorder_mock: Recorder, hass: HomeAssistant, hass_ws_client: WebSocketGenerator
) -> None:
"""Test the option flow preview where the config entry is removed."""
client = await hass_ws_client(hass)
# Setup the config entry
config_entry = MockConfigEntry(
data={},
domain=DOMAIN,
options={
CONF_NAME: "Get Value",
CONF_QUERY: "SELECT 5 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
title="Get Value",
)
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["preview"] == "sql"
await hass.config_entries.async_remove(config_entry.entry_id)
await client.send_json_auto_id(
{
"type": "sql/start_preview",
"flow_id": result["flow_id"],
"flow_type": "options_flow",
"user_input": {
CONF_QUERY: "SELECT 6 as value",
CONF_COLUMN_NAME: "value",
CONF_UNIT_OF_MEASUREMENT: "MiB",
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
CONF_STATE_CLASS: SensorStateClass.TOTAL,
},
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "home_assistant_error",
"message": "Config entry not found",
}