Files

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

95 lines
2.6 KiB
Python
Raw Permalink Normal View History

"""Test the sql utils."""
from datetime import UTC, date, datetime
from decimal import Decimal
import pytest
import voluptuous as vol
from homeassistant.components.recorder import Recorder, get_instance
from homeassistant.components.sql.util import (
convert_value,
resolve_db_url,
validate_sql_select,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.template import Template
async def test_resolve_db_url_when_none_configured(
    recorder_mock: Recorder,
    hass: HomeAssistant,
) -> None:
    """Test that the recorder's db_url is returned when no db_url is provided.

    ``recorder_mock`` is not referenced in the body; presumably it is requested
    so that a recorder instance exists for ``get_instance(hass)`` — confirm
    against the fixture definition.
    """
    db_url = None

    resolved_url = resolve_db_url(hass, db_url)

    # With no explicit URL, resolution must fall back to the recorder's URL.
    assert resolved_url == get_instance(hass).db_url


async def test_resolve_db_url_when_configured(hass: HomeAssistant) -> None:
    """Test that an explicitly provided db_url is returned unchanged."""
    db_url = "mssql://"

    resolved_url = resolve_db_url(hass, db_url)

    assert resolved_url == db_url


@pytest.mark.parametrize(
    ("sql_query", "expected_error_message"),
    [
        # Non-SELECT statement.
        (
            "DROP TABLE *",
            "SQL query must be of type SELECT",
        ),
        # Missing whitespace after SELECT: not recognised as a statement type.
        (
            "SELECT5 as value",
            "SQL query is empty or unknown type",
        ),
        # Only statement separators, no actual statement.
        (
            ";;",
            "SQL query is empty or unknown type",
        ),
        (
            "UPDATE states SET state = 999999 WHERE state_id = 11125",
            "SQL query must be of type SELECT",
        ),
        # A CTE containing a SELECT whose final statement is an UPDATE must
        # still be rejected.
        (
            "WITH test AS (SELECT state FROM states) UPDATE states SET states.state = test.state;",
            "SQL query must be of type SELECT",
        ),
        # A valid SELECT followed by a second statement.
        (
            "SELECT 5 as value; UPDATE states SET state = 10;",
            "Multiple SQL statements are not allowed",
        ),
    ],
)
async def test_invalid_sql_queries(
    hass: HomeAssistant,
    sql_query: str,
    expected_error_message: str,
) -> None:
    """Test that invalid or disallowed SQL queries raise ``vol.Invalid``.

    ``pytest.raises(match=...)`` applies ``re.search`` to the exception text,
    so each expected message is treated as a regex pattern. None of the
    messages above contain regex metacharacters.
    """
    with pytest.raises(vol.Invalid, match=expected_error_message):
        validate_sql_select(Template(sql_query, hass))


@pytest.mark.parametrize(
    ("value", "expected_output"),
    [
        (Decimal("199.99"), 199.99),
        (date(2023, 1, 15), "2023-01-15"),
        (datetime(2023, 1, 15, 12, 30, 45, tzinfo=UTC), "2023-01-15T12:30:45+00:00"),
        (b"\xde\xad\xbe\xef", "0xdeadbeef"),
        # Values that are already plain str/float/int are expected unchanged.
        ("deadbeef", "deadbeef"),
        (199.99, 199.99),
        (69, 69),
    ],
)
async def test_value_conversion(
    value: Decimal | date | datetime | bytes | str | float,
    expected_output: str | float,
) -> None:
    """Test that ``convert_value`` maps database values to the expected output.

    Per the cases above:
    - ``Decimal`` becomes ``float``.
    - ``date``/``datetime`` become ISO 8601 strings.
    - ``bytes`` become a ``0x``-prefixed hex string.
    - ``str``/``float``/``int`` are returned unchanged.
    """
    # Parameter is named ``value`` rather than ``input`` to avoid shadowing
    # the builtin.
    assert convert_value(value) == expected_output