diff --git a/homeassistant/components/sql/config_flow.py b/homeassistant/components/sql/config_flow.py index 37a6f9ef104..3d7e56fd35d 100644 --- a/homeassistant/components/sql/config_flow.py +++ b/homeassistant/components/sql/config_flow.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Mapping import logging from typing import Any @@ -13,9 +14,11 @@ import sqlparse from sqlparse.exceptions import SQLParseError import voluptuous as vol +from homeassistant.components import websocket_api from homeassistant.components.recorder import CONF_DB_URL, get_instance from homeassistant.components.sensor import ( CONF_STATE_CLASS, + DOMAIN as SENSOR_DOMAIN, SensorDeviceClass, SensorStateClass, ) @@ -31,10 +34,15 @@ from homeassistant.const import ( CONF_UNIT_OF_MEASUREMENT, CONF_VALUE_TEMPLATE, ) -from homeassistant.core import callback +from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError, TemplateError from homeassistant.helpers import selector +from homeassistant.helpers.entity_platform import PlatformData +from homeassistant.helpers.template import Template +from homeassistant.helpers.trigger_template_entity import ValueTemplate from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN +from .sensor import TRIGGER_ENTITY_OPTIONS, SQLSensor, get_db_connection from .util import resolve_db_url _LOGGER = logging.getLogger(__name__) @@ -138,6 +146,11 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 + @staticmethod + async def async_setup_preview(hass: HomeAssistant) -> None: + """Set up preview WS API.""" + websocket_api.async_register_command(hass, ws_start_preview) + @staticmethod @callback def async_get_options_flow( @@ -206,6 +219,7 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN): data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input), errors=errors, description_placeholders=description_placeholders, + preview="sql", ) @@ -279,4 +293,103 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload): ), errors=errors, description_placeholders=description_placeholders, + preview="sql", ) + + @staticmethod + async def async_setup_preview(hass: HomeAssistant) -> None: + """Set up preview WS API.""" + websocket_api.async_register_command(hass, ws_start_preview) + + +@websocket_api.websocket_command( + { + vol.Required("type"): "sql/start_preview", + vol.Required("flow_id"): str, + vol.Required("flow_type"): vol.Any("config_flow", "options_flow"), + vol.Required("user_input"): dict, + } +) +@websocket_api.async_response +async def ws_start_preview( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Generate a preview.""" + + if msg["flow_type"] == "config_flow": + flow_status = hass.config_entries.flow.async_get(msg["flow_id"]) + flow_sets = hass.config_entries.flow._handler_progress_index.get( # noqa: SLF001 + flow_status["handler"] + ) + assert flow_sets + config_entry = hass.config_entries.async_get_entry(flow_status["handler"]) + name = msg["user_input"][CONF_NAME] + + else: + flow_status = hass.config_entries.options.async_get(msg["flow_id"]) + config_entry = hass.config_entries.async_get_entry(flow_status["handler"]) + if not config_entry: + raise HomeAssistantError("Config entry not found") + name = config_entry.options[CONF_NAME] + + @callback + def async_preview_updated(state: str, attributes: Mapping[str, Any]) -> None: + """Forward config entry state events to websocket.""" + connection.send_message( + websocket_api.event_message( + msg["id"], {"attributes": attributes, "state": state} + ) + ) + + db_url = resolve_db_url(hass, msg["user_input"].get(CONF_DB_URL)) + + if ( + db_connection := await get_db_connection( + hass, + db_url, + ) + ) is None: + return # Missing test + sessmaker = db_connection[0] + use_database_executor = db_connection[1] + + name_template = Template(name, hass) + trigger_entity_config = {CONF_NAME: name_template} + for key in TRIGGER_ENTITY_OPTIONS: + if key not in msg["user_input"]: + continue + trigger_entity_config[key] = msg["user_input"][key] + + query_str: str = msg["user_input"].get(CONF_QUERY) + template: str | None = msg["user_input"].get(CONF_VALUE_TEMPLATE) + column_name: str = msg["user_input"].get(CONF_COLUMN_NAME) + + value_template: ValueTemplate | None = None + if template is not None: + try: + value_template = ValueTemplate(template, hass) + value_template.ensure_valid() + except TemplateError: + value_template = None + + preview_entity = SQLSensor( + trigger_entity_config=trigger_entity_config, + sessmaker=sessmaker, + query=query_str, + column=column_name, + value_template=value_template, + yaml=False, + use_database_executor=use_database_executor, + ) + preview_entity.hass = hass + + # Create PlatformData, needed for name translations + platform_data = PlatformData(hass=hass, domain=SENSOR_DOMAIN, platform_name=DOMAIN) + await platform_data.async_load_translations() + + connection.send_result(msg["id"]) + connection.subscriptions[msg["id"]] = await preview_entity.async_start_preview( + async_preview_updated + ) diff --git a/homeassistant/components/sql/sensor.py b/homeassistant/components/sql/sensor.py index 8c0ba81d6d2..bae0e6df710 100644 --- a/homeassistant/components/sql/sensor.py +++ b/homeassistant/components/sql/sensor.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Callable, Mapping from datetime import date import decimal import logging @@ -10,7 +11,7 @@ from typing import Any import sqlalchemy from sqlalchemy import lambda_stmt from sqlalchemy.engine import Result -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.exc import NoSuchColumnError, SQLAlchemyError from sqlalchemy.orm import Session, scoped_session, sessionmaker from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.util import LRUCache @@ -32,7 +33,7 @@ from homeassistant.const import ( EVENT_HOMEASSISTANT_STOP, MATCH_ALL, ) -from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback from homeassistant.exceptions import TemplateError from homeassistant.helpers import issue_registry as ir from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo @@ -175,18 +176,11 @@ def _async_get_or_init_domain_data(hass: HomeAssistant) -> SQLData: return sql_data -async def async_setup_sensor( +async def get_db_connection( hass: HomeAssistant, - trigger_entity_config: ConfigType, - query_str: str, - column_name: str, - value_template: ValueTemplate | None, - unique_id: str | None, db_url: str, - yaml: bool, - async_add_entities: AddEntitiesCallback | AddConfigEntryEntitiesCallback, -) -> None: - """Set up the SQL sensor.""" +) -> tuple[scoped_session, bool, bool] | None: + """Get a database connection.""" try: instance = get_instance(hass) except KeyError: # No recorder loaded @@ -213,7 +207,33 @@ async def async_setup_sensor( ): sql_data.session_makers_by_db_url[db_url] = sessmaker else: + return None + return sessmaker, use_database_executor, uses_recorder_db + + +async def async_setup_sensor( + hass: HomeAssistant, + trigger_entity_config: ConfigType, + query_str: str, + column_name: str, + value_template: ValueTemplate | None, + unique_id: str | None, + db_url: str, + yaml: bool, + async_add_entities: AddEntitiesCallback | AddConfigEntryEntitiesCallback, +) -> None: + """Set up the SQL sensor.""" + + if ( + db_connection := await get_db_connection( + hass, + db_url, + ) + ) is None: return + sessmaker = db_connection[0] + use_database_executor = db_connection[1] + uses_recorder_db = db_connection[2] upper_query = query_str.upper() if uses_recorder_db: @@ -339,6 +359,7 @@ class SQLSensor(ManualTriggerSensorEntity): manufacturer="SQL", name=self._rendered.get(CONF_NAME), ) + self._preview_callback: Callable[[str, Mapping[str, Any]], None] | None = None @property def name(self) -> str | None: @@ -357,12 +378,32 @@ class SQLSensor(ManualTriggerSensorEntity): """Return extra attributes.""" return dict(self._attr_extra_state_attributes) + async def async_start_preview( + self, + preview_callback: Callable[[str, Mapping[str, Any]], None], + ) -> CALLBACK_TYPE: + """Render a preview.""" + # abort early if there is needed data missing + if not self._query or not self._column_name: + self._attr_available = False + calculated_state = self._async_calculate_state() + preview_callback(calculated_state.state, calculated_state.attributes) + return self._call_on_remove_callbacks + + self._preview_callback = preview_callback + + await self.async_update() + return self._call_on_remove_callbacks + async def async_update(self) -> None: """Retrieve sensor data from the query using the right executor.""" if self._use_database_executor: await get_instance(self.hass).async_add_executor_job(self._update) else: await self.hass.async_add_executor_job(self._update) + if self._preview_callback: + calculated_state = self._async_calculate_state() + self._preview_callback(calculated_state.state, calculated_state.attributes) def _update(self) -> None: """Retrieve sensor data from the query.""" @@ -384,7 +425,19 @@ class SQLSensor(ManualTriggerSensorEntity): for res in result.mappings(): _LOGGER.debug("Query %s result in %s", self._query, res.items()) - data = res[self._column_name] + try: + data = res[self._column_name] + except NoSuchColumnError as err: + _LOGGER.error( + "Column %s not found in query result for query %s: %s", + self._column_name, + self._query, + redact_credentials(str(err)), + ) + sess.rollback() + sess.close() + return + for key, value in res.items(): if isinstance(value, decimal.Decimal): value = float(value) @@ -398,6 +451,8 @@ class SQLSensor(ManualTriggerSensorEntity): if data is not None and isinstance(data, (bytes, bytearray)): data = f"0x{data.hex()}" + print(data, self._template) + if data is not None and self._template is not None: variables = self._template_variables_with_value(data) if self._render_availability_template(variables): @@ -406,6 +461,7 @@ class SQLSensor(ManualTriggerSensorEntity): ) self._set_native_value_with_possible_timestamp(_value) self._process_manual_data(variables) + print(self._attr_native_value) else: self._attr_native_value = data diff --git a/tests/components/sql/snapshots/test_config_flow.ambr b/tests/components/sql/snapshots/test_config_flow.ambr new file mode 100644 index 00000000000..3af4b71000a --- /dev/null +++ b/tests/components/sql/snapshots/test_config_flow.ambr @@ -0,0 +1,80 @@ +# serializer version: 1 +# name: test_config_flow_preview[incorrect_column] + dict({ + 'attributes': dict({ + 'device_class': 'data_size', + 'friendly_name': 'Get Value', + 'state_class': 'total', + 'unit_of_measurement': 'MiB', + }), + 'state': 'unknown', + }) +# --- +# name: test_config_flow_preview[missing_column] + dict({ + 'attributes': dict({ + 'friendly_name': 'Get Value', + }), + 'state': 'unknown', + }) +# --- +# name: test_config_flow_preview[success] + dict({ + 'attributes': dict({ + 'device_class': 'data_size', + 'friendly_name': 'Get Value', + 'state_class': 'total', + 'unit_of_measurement': 'MiB', + 'value': 5, + }), + 'state': '5', + }) +# --- +# name: test_config_flow_preview[with_value_template] + dict({ + 'attributes': dict({ + 'device_class': 'data_size', + 'friendly_name': 'Get Value', + 'state_class': 'total', + 'unit_of_measurement': 'MiB', + 'value': 5, + }), + 'state': '5', + }) +# --- +# name: test_config_flow_preview[with_value_template_invalid] + dict({ + 'attributes': dict({ + 'device_class': 'data_size', + 'friendly_name': 'Get Value', + 'state_class': 'total', + 'unit_of_measurement': 'MiB', + 'value': 5, + }), + 'state': '5', + }) +# --- +# name: test_config_flow_preview_no_database + dict({ + 'attributes': dict({ + 'device_class': 'data_size', + 'friendly_name': 'Get Value', + 'state_class': 'total', + 'unit_of_measurement': 'MiB', + 'value': 5, + }), + 'state': '5', + }) +# --- +# name: test_options_flow_preview + dict({ + 'attributes': dict({ + 'device_class': 'data_size', + 'friendly_name': 'Get Value', + 'state_class': 'total', + 'unit_of_measurement': 'MiB', + 'value': 6, + }), + 'state': '6', + }) +# --- diff --git a/tests/components/sql/test_config_flow.py b/tests/components/sql/test_config_flow.py index 3f2400c0a32..b75adcf14be 100644 --- a/tests/components/sql/test_config_flow.py +++ b/tests/components/sql/test_config_flow.py @@ -5,12 +5,24 @@ from __future__ import annotations from pathlib import Path from unittest.mock import patch +import pytest from sqlalchemy.exc import SQLAlchemyError +from syrupy.assertion import SnapshotAssertion from homeassistant import config_entries -from homeassistant.components.recorder import Recorder -from homeassistant.components.sensor import SensorDeviceClass, SensorStateClass -from homeassistant.components.sql.const import DOMAIN +from homeassistant.components.recorder import CONF_DB_URL, Recorder +from homeassistant.components.sensor import ( + CONF_STATE_CLASS, + SensorDeviceClass, + SensorStateClass, +) +from homeassistant.components.sql.const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN +from homeassistant.const import ( + CONF_DEVICE_CLASS, + CONF_NAME, + CONF_UNIT_OF_MEASUREMENT, + CONF_VALUE_TEMPLATE, +) from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -35,6 +47,7 @@ from . import ( ) from tests.common import MockConfigEntry +from tests.typing import WebSocketGenerator async def test_form(recorder_mock: Recorder, hass: HomeAssistant) -> None: @@ -795,3 +808,249 @@ async def test_device_state_class(recorder_mock: Recorder, hass: HomeAssistant) "column": "value", "unit_of_measurement": "MiB", } + + +@pytest.mark.parametrize( + "user_input", + [ + ( + { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + } + ), + ( + { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as value", + CONF_COLUMN_NAME: "state", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + } + ), + ( + { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as value", + } + ), + ( + { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as value", + CONF_COLUMN_NAME: "value", + CONF_VALUE_TEMPLATE: "{{ value }}", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + } + ), + ( + { + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as value", + CONF_COLUMN_NAME: "value", + CONF_VALUE_TEMPLATE: "{{ value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + } + ), + ], + ids=( + "success", + "incorrect_column", + "missing_column", + "with_value_template", + "with_value_template_invalid", + ), +) +async def test_config_flow_preview( + recorder_mock: Recorder, + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + user_input: str, + snapshot: SnapshotAssertion, +) -> None: + """Test the config flow preview.""" + client = await hass_ws_client(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "user" + assert result["errors"] == {} + assert result["preview"] == "sql" + + await client.send_json_auto_id( + { + "type": "sql/start_preview", + "flow_id": result["flow_id"], + "flow_type": "config_flow", + "user_input": user_input, + } + ) + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] is None + + msg = await client.receive_json() + assert msg["event"] == snapshot + assert len(hass.states.async_all()) == 0 + + +async def test_config_flow_preview_no_database( + recorder_mock: Recorder, + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test the config flow preview with no database.""" + client = await hass_ws_client(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "user" + assert result["errors"] == {} + assert result["preview"] == "sql" + + await client.send_json_auto_id( + { + "type": "sql/start_preview", + "flow_id": result["flow_id"], + "flow_type": "config_flow", + "user_input": { + CONF_DB_URL: "sqlite://homeassistant.local", + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + }, + } + ) + # msg = await client.receive_json() + # assert msg["success"] + # assert msg["result"] is None + + # msg = await client.receive_json() + # assert msg["event"] == snapshot + # assert len(hass.states.async_all()) == 0 + await hass.async_block_till_done(wait_background_tasks=True) + assert False + + +async def test_options_flow_preview( + recorder_mock: Recorder, + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, +) -> None: + """Test the options flow preview.""" + client = await hass_ws_client(hass) + + # Setup the config entry + config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + }, + title="Get Value", + ) + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == FlowResultType.FORM + assert result["preview"] == "sql" + + await client.send_json_auto_id( + { + "type": "sql/start_preview", + "flow_id": result["flow_id"], + "flow_type": "options_flow", + "user_input": { + CONF_QUERY: "SELECT 6 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + }, + } + ) + + msg = await client.receive_json() + assert msg["success"] + assert msg["result"] is None + + msg = await client.receive_json() + assert msg["event"] == snapshot + assert len(hass.states.async_all()) == 1 + + +async def test_options_flow_sensor_preview_config_entry_removed( + recorder_mock: Recorder, hass: HomeAssistant, hass_ws_client: WebSocketGenerator +) -> None: + """Test the option flow preview where the config entry is removed.""" + client = await hass_ws_client(hass) + + # Setup the config entry + config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + CONF_NAME: "Get Value", + CONF_QUERY: "SELECT 5 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + }, + title="Get Value", + ) + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() + + result = await hass.config_entries.options.async_init(config_entry.entry_id) + assert result["type"] == FlowResultType.FORM + assert result["preview"] == "sql" + + await hass.config_entries.async_remove(config_entry.entry_id) + + await client.send_json_auto_id( + { + "type": "sql/start_preview", + "flow_id": result["flow_id"], + "flow_type": "options_flow", + "user_input": { + CONF_QUERY: "SELECT 6 as value", + CONF_COLUMN_NAME: "value", + CONF_UNIT_OF_MEASUREMENT: "MiB", + CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE, + CONF_STATE_CLASS: SensorStateClass.TOTAL, + }, + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "home_assistant_error", + "message": "Config entry not found", + }