mirror of
https://github.com/home-assistant/core.git
synced 2025-09-05 21:01:37 +02:00
Improve sql config flow
This commit is contained in:
@@ -6,7 +6,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.engine import Result
|
||||
from sqlalchemy.engine import Engine, Result
|
||||
from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
import sqlparse
|
||||
@@ -32,9 +32,10 @@ from homeassistant.const import (
|
||||
CONF_VALUE_TEMPLATE,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import section
|
||||
from homeassistant.helpers import selector
|
||||
|
||||
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||
from .util import resolve_db_url
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -42,40 +43,38 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
OPTIONS_SCHEMA: vol.Schema = vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_DB_URL,
|
||||
): selector.TextSelector(),
|
||||
vol.Required(
|
||||
CONF_COLUMN_NAME,
|
||||
): selector.TextSelector(),
|
||||
vol.Required(
|
||||
CONF_QUERY,
|
||||
): selector.TextSelector(selector.TextSelectorConfig(multiline=True)),
|
||||
vol.Optional(
|
||||
CONF_UNIT_OF_MEASUREMENT,
|
||||
): selector.TextSelector(),
|
||||
vol.Optional(
|
||||
CONF_VALUE_TEMPLATE,
|
||||
): selector.TemplateSelector(),
|
||||
vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector(
|
||||
selector.SelectSelectorConfig(
|
||||
options=[
|
||||
cls.value
|
||||
for cls in SensorDeviceClass
|
||||
if cls != SensorDeviceClass.ENUM
|
||||
],
|
||||
mode=selector.SelectSelectorMode.DROPDOWN,
|
||||
translation_key="device_class",
|
||||
sort=True,
|
||||
)
|
||||
vol.Required(CONF_COLUMN_NAME): selector.TextSelector(),
|
||||
vol.Required(CONF_QUERY): selector.TextSelector(
|
||||
selector.TextSelectorConfig(multiline=True)
|
||||
),
|
||||
vol.Optional(CONF_STATE_CLASS): selector.SelectSelector(
|
||||
selector.SelectSelectorConfig(
|
||||
options=[cls.value for cls in SensorStateClass],
|
||||
mode=selector.SelectSelectorMode.DROPDOWN,
|
||||
translation_key="state_class",
|
||||
sort=True,
|
||||
)
|
||||
vol.Required(CONF_ADVANCED_OPTIONS): section(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_VALUE_TEMPLATE): selector.TemplateSelector(),
|
||||
vol.Optional(CONF_UNIT_OF_MEASUREMENT): selector.TextSelector(),
|
||||
vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector(
|
||||
selector.SelectSelectorConfig(
|
||||
options=[
|
||||
cls.value
|
||||
for cls in SensorDeviceClass
|
||||
if cls != SensorDeviceClass.ENUM
|
||||
],
|
||||
mode=selector.SelectSelectorMode.DROPDOWN,
|
||||
translation_key="device_class",
|
||||
sort=True,
|
||||
)
|
||||
),
|
||||
vol.Optional(CONF_STATE_CLASS): selector.SelectSelector(
|
||||
selector.SelectSelectorConfig(
|
||||
options=[cls.value for cls in SensorStateClass],
|
||||
mode=selector.SelectSelectorMode.DROPDOWN,
|
||||
translation_key="state_class",
|
||||
sort=True,
|
||||
)
|
||||
),
|
||||
}
|
||||
),
|
||||
{"collapsed": True},
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -83,8 +82,9 @@ OPTIONS_SCHEMA: vol.Schema = vol.Schema(
|
||||
CONFIG_SCHEMA: vol.Schema = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_NAME, default="Select SQL Query"): selector.TextSelector(),
|
||||
vol.Optional(CONF_DB_URL): selector.TextSelector(),
|
||||
}
|
||||
).extend(OPTIONS_SCHEMA.schema)
|
||||
)
|
||||
|
||||
|
||||
def validate_sql_select(value: str) -> str:
|
||||
@@ -99,6 +99,31 @@ def validate_sql_select(value: str) -> str:
|
||||
return str(query[0])
|
||||
|
||||
|
||||
def validate_db_connection(db_url: str) -> bool:
|
||||
"""Validate db connection."""
|
||||
|
||||
engine: Engine | None = None
|
||||
sess: Session | None = None
|
||||
try:
|
||||
engine = sqlalchemy.create_engine(db_url, future=True)
|
||||
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
|
||||
sess = sessmaker()
|
||||
sess.execute(sqlalchemy.text("select 1 as value"))
|
||||
except SQLAlchemyError as error:
|
||||
_LOGGER.debug("Execution error %s", error)
|
||||
if sess:
|
||||
sess.close()
|
||||
if engine:
|
||||
engine.dispose()
|
||||
raise
|
||||
|
||||
if sess:
|
||||
sess.close()
|
||||
engine.dispose()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_query(db_url: str, query: str, column: str) -> bool:
|
||||
"""Validate SQL query."""
|
||||
|
||||
@@ -138,6 +163,8 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
|
||||
VERSION = 1
|
||||
|
||||
data: dict[str, Any]
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(
|
||||
@@ -151,17 +178,46 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the user step."""
|
||||
errors = {}
|
||||
description_placeholders = {}
|
||||
|
||||
if user_input is not None:
|
||||
db_url = user_input.get(CONF_DB_URL)
|
||||
|
||||
try:
|
||||
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_db_connection, db_url_for_validation
|
||||
)
|
||||
except SQLAlchemyError:
|
||||
errors["db_url"] = "db_url_invalid"
|
||||
|
||||
if not errors:
|
||||
self.data = {CONF_NAME: user_input[CONF_NAME]}
|
||||
if db_url and db_url_for_validation != get_instance(self.hass).db_url:
|
||||
self.data[CONF_DB_URL] = db_url
|
||||
return await self.async_step_options()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_options(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle the user step."""
|
||||
errors = {}
|
||||
description_placeholders = {}
|
||||
|
||||
if user_input is not None:
|
||||
query = user_input[CONF_QUERY]
|
||||
column = user_input[CONF_COLUMN_NAME]
|
||||
db_url_for_validation = None
|
||||
|
||||
try:
|
||||
query = validate_sql_select(query)
|
||||
db_url_for_validation = resolve_db_url(self.hass, db_url)
|
||||
db_url_for_validation = resolve_db_url(
|
||||
self.hass, self.data.get(CONF_DB_URL)
|
||||
)
|
||||
await self.hass.async_add_executor_job(
|
||||
validate_query, db_url_for_validation, query, column
|
||||
)
|
||||
@@ -178,32 +234,22 @@ class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
_LOGGER.debug("Invalid query: %s", err)
|
||||
errors["query"] = "query_invalid"
|
||||
|
||||
options = {
|
||||
CONF_QUERY: query,
|
||||
CONF_COLUMN_NAME: column,
|
||||
CONF_NAME: user_input[CONF_NAME],
|
||||
}
|
||||
if uom := user_input.get(CONF_UNIT_OF_MEASUREMENT):
|
||||
options[CONF_UNIT_OF_MEASUREMENT] = uom
|
||||
if value_template := user_input.get(CONF_VALUE_TEMPLATE):
|
||||
options[CONF_VALUE_TEMPLATE] = value_template
|
||||
if device_class := user_input.get(CONF_DEVICE_CLASS):
|
||||
options[CONF_DEVICE_CLASS] = device_class
|
||||
if state_class := user_input.get(CONF_STATE_CLASS):
|
||||
options[CONF_STATE_CLASS] = state_class
|
||||
if db_url_for_validation != get_instance(self.hass).db_url:
|
||||
options[CONF_DB_URL] = db_url_for_validation
|
||||
for k, v in user_input[CONF_ADVANCED_OPTIONS].items():
|
||||
if not v:
|
||||
user_input[CONF_ADVANCED_OPTIONS].pop(k)
|
||||
|
||||
if not errors:
|
||||
name = self.data[CONF_NAME]
|
||||
self.data.pop(CONF_NAME)
|
||||
return self.async_create_entry(
|
||||
title=user_input[CONF_NAME],
|
||||
data={},
|
||||
options=options,
|
||||
title=name,
|
||||
data=self.data,
|
||||
options=user_input,
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
|
||||
step_id="options",
|
||||
data_schema=self.add_suggested_values_to_schema(OPTIONS_SCHEMA, user_input),
|
||||
errors=errors,
|
||||
description_placeholders=description_placeholders,
|
||||
)
|
||||
@@ -220,10 +266,9 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload):
|
||||
description_placeholders = {}
|
||||
|
||||
if user_input is not None:
|
||||
db_url = user_input.get(CONF_DB_URL)
|
||||
db_url = self.config_entry.data.get(CONF_DB_URL)
|
||||
query = user_input[CONF_QUERY]
|
||||
column = user_input[CONF_COLUMN_NAME]
|
||||
name = self.config_entry.options.get(CONF_NAME, self.config_entry.title)
|
||||
|
||||
try:
|
||||
query = validate_sql_select(query)
|
||||
@@ -252,24 +297,12 @@ class SQLOptionsFlowHandler(OptionsFlowWithReload):
|
||||
recorder_db,
|
||||
)
|
||||
|
||||
options = {
|
||||
CONF_QUERY: query,
|
||||
CONF_COLUMN_NAME: column,
|
||||
CONF_NAME: name,
|
||||
}
|
||||
if uom := user_input.get(CONF_UNIT_OF_MEASUREMENT):
|
||||
options[CONF_UNIT_OF_MEASUREMENT] = uom
|
||||
if value_template := user_input.get(CONF_VALUE_TEMPLATE):
|
||||
options[CONF_VALUE_TEMPLATE] = value_template
|
||||
if device_class := user_input.get(CONF_DEVICE_CLASS):
|
||||
options[CONF_DEVICE_CLASS] = device_class
|
||||
if state_class := user_input.get(CONF_STATE_CLASS):
|
||||
options[CONF_STATE_CLASS] = state_class
|
||||
if db_url_for_validation != get_instance(self.hass).db_url:
|
||||
options[CONF_DB_URL] = db_url_for_validation
|
||||
for k, v in user_input[CONF_ADVANCED_OPTIONS].items():
|
||||
if not v:
|
||||
user_input[CONF_ADVANCED_OPTIONS].pop(k)
|
||||
|
||||
return self.async_create_entry(
|
||||
data=options,
|
||||
data=user_input,
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
|
@@ -9,4 +9,5 @@ PLATFORMS = [Platform.SENSOR]
|
||||
|
||||
CONF_COLUMN_NAME = "column"
|
||||
CONF_QUERY = "query"
|
||||
CONF_ADVANCED_OPTIONS = "advanced_options"
|
||||
DB_URL_RE = re.compile("//.*:.*@")
|
||||
|
@@ -10,7 +10,12 @@ from homeassistant.components.sensor import (
|
||||
SensorDeviceClass,
|
||||
SensorStateClass,
|
||||
)
|
||||
from homeassistant.components.sql.const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
|
||||
from homeassistant.components.sql.const import (
|
||||
CONF_ADVANCED_OPTIONS,
|
||||
CONF_COLUMN_NAME,
|
||||
CONF_QUERY,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.config_entries import SOURCE_USER
|
||||
from homeassistant.const import (
|
||||
CONF_DEVICE_CLASS,
|
||||
@@ -30,140 +35,167 @@ from homeassistant.helpers.trigger_template_entity import (
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
ENTRY_CONFIG = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
|
||||
CONF_STATE_CLASS: SensorStateClass.TOTAL,
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_DEVICE_CLASS: SensorDeviceClass.DATA_SIZE,
|
||||
CONF_STATE_CLASS: SensorStateClass.TOTAL,
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_WITH_VALUE_TEMPLATE = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 as value",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_VALUE_TEMPLATE: "{{ value }}",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_VALUE_TEMPLATE: "{{ value }}",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_2 = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_3 = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: ";;",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_OPT = {
|
||||
CONF_QUERY: "SELECT 5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_2_OPT = {
|
||||
CONF_QUERY: "SELECT5 FROM as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_QUERY_3_OPT = {
|
||||
CONF_QUERY: ";;",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_QUERY_READ_ONLY_CTE = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "UPDATE states SET state = 999999 WHERE state_id = 11125",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_READ_ONLY_CTE_OPT = {
|
||||
CONF_QUERY: "WITH test AS (SELECT 1 AS row_num, 10 AS state) SELECT state FROM test WHERE row_num = 1 LIMIT 1;",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_OPT = {
|
||||
CONF_QUERY: "UPDATE 5 as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_QUERY_NO_READ_ONLY_CTE_OPT = {
|
||||
CONF_QUERY: "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_MULTIPLE_QUERIES = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_MULTIPLE_QUERIES_OPT = {
|
||||
CONF_QUERY: "SELECT 5 as state; UPDATE states SET state = 10;",
|
||||
CONF_COLUMN_NAME: "state",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
ENTRY_CONFIG_INVALID_COLUMN_NAME = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT 5 as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_INVALID_COLUMN_NAME_OPT = {
|
||||
CONF_QUERY: "SELECT 5 as value",
|
||||
CONF_COLUMN_NAME: "size",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
ENTRY_CONFIG_NO_RESULTS = {
|
||||
CONF_NAME: "Get Value",
|
||||
CONF_QUERY: "SELECT kalle as value from no_table;",
|
||||
CONF_COLUMN_NAME: "value",
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
CONF_ADVANCED_OPTIONS: {
|
||||
CONF_UNIT_OF_MEASUREMENT: "MiB",
|
||||
},
|
||||
}
|
||||
|
||||
YAML_CONFIG = {
|
||||
|
41
tests/components/sql/conftest.py
Normal file
41
tests/components/sql/conftest.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Fixtures for the SQL integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_setup_entry() -> Generator[AsyncMock]:
|
||||
"""Override async_setup_entry."""
|
||||
with patch(
|
||||
"homeassistant.components.sql.async_setup_entry", return_value=True
|
||||
) as mock_setup_entry:
|
||||
yield mock_setup_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def create_db(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
) -> str:
|
||||
"""Test the SQL sensor with a query that returns no value."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db_path_str = f"sqlite:///{db_path}"
|
||||
|
||||
def make_test_db():
|
||||
"""Create a test database."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute("CREATE TABLE users (value INTEGER)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
await hass.async_add_executor_job(make_test_db)
|
||||
return db_path_str
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user