This commit is contained in:
J. Nick Koston
2023-04-10 13:35:02 -10:00
parent 7a68da8243
commit 8b5c650ce3
2 changed files with 76 additions and 45 deletions

View File

@@ -16,11 +16,11 @@ from sqlalchemy import (
and_, and_,
func, func,
lambda_stmt, lambda_stmt,
literal,
select, select,
union_all, union_all,
) )
from sqlalchemy.engine.row import Row from sqlalchemy.engine.row import Row
from sqlalchemy.orm.properties import MappedColumn
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from homeassistant.const import COMPRESSED_STATE_LAST_UPDATED, COMPRESSED_STATE_STATE from homeassistant.const import COMPRESSED_STATE_LAST_UPDATED, COMPRESSED_STATE_STATE
@@ -45,28 +45,13 @@ from .const import (
STATE_KEY, STATE_KEY,
) )
_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED = (
States.metadata_id,
States.state,
States.last_updated_ts,
)
_QUERY_STATE_NO_ATTR = (
*_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED,
States.last_changed_ts,
)
_QUERY_ATTRIBUTES = (
# Remove States.attributes once all attributes are in StateAttributes.shared_attrs
States.attributes,
StateAttributes.shared_attrs,
)
_QUERY_STATES = (*_QUERY_STATE_NO_ATTR, *_QUERY_ATTRIBUTES)
_QUERY_STATES_NO_LAST_CHANGED = (
*_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED,
*_QUERY_ATTRIBUTES,
)
_FIELD_MAP = { _FIELD_MAP = {
cast(MappedColumn, field).name: idx "metadata_id": 0,
for idx, field in enumerate(_QUERY_STATE_NO_ATTR) "state": 1,
"last_updated_ts": 2,
"last_changed_ts": 3,
"attributes": 4,
"shared_attrs": 5,
} }
@@ -74,17 +59,25 @@ def _stmt_and_join_attributes(
no_attributes: bool, include_last_changed: bool no_attributes: bool, include_last_changed: bool
) -> Select: ) -> Select:
"""Return the statement and if StateAttributes should be joined.""" """Return the statement and if StateAttributes should be joined."""
# If no_attributes was requested we do the query _select = select(States.metadata_id, States.state, States.last_updated_ts)
# without the attributes fields and do not join the
# state_attributes table
if no_attributes:
if include_last_changed:
return select(*_QUERY_STATE_NO_ATTR)
return select(*_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED)
if include_last_changed: if include_last_changed:
return select(*_QUERY_STATES) _select = _select.add_columns(States.last_changed_ts)
return select(*_QUERY_STATES_NO_LAST_CHANGED) if not no_attributes:
_select = _select.add_columns(States.attributes, StateAttributes.shared_attrs)
return _select
def _stmt_and_join_attributes_for_epoch_time(
no_attributes: bool, include_last_changed: bool, epoch_time: float
) -> Select:
"""Return the statement and if StateAttributes should be joined."""
_select = select(States.metadata_id, States.state)
_select = _select.add_columns(literal(value=None).label("last_updated_ts"))
if include_last_changed:
_select = _select.add_columns(literal(value=None).label("last_changed_ts"))
if not no_attributes:
_select = _select.add_columns(States.attributes, StateAttributes.shared_attrs)
return _select
def _select_from_subquery( def _select_from_subquery(
@@ -167,12 +160,14 @@ def _significant_states_stmt(
stmt = stmt.order_by(States.metadata_id, States.last_updated_ts) stmt = stmt.order_by(States.metadata_id, States.last_updated_ts)
if not include_start_time_state or not run_start_ts: if not include_start_time_state or not run_start_ts:
return stmt return stmt
start_time_ts_str = str(start_time_ts)
return _select_from_subquery( return _select_from_subquery(
union_all( union_all(
_select_from_subquery( _select_from_subquery(
_get_start_time_state_stmt( _get_start_time_state_stmt(
run_start_ts, run_start_ts,
start_time_ts, start_time_ts,
start_time_ts_str,
single_metadata_id, single_metadata_id,
metadata_ids, metadata_ids,
no_attributes, no_attributes,
@@ -263,6 +258,7 @@ def get_significant_states_with_session(
) )
return _sorted_states_to_dict( return _sorted_states_to_dict(
execute_stmt_lambda_element(session, stmt, None, end_time), execute_stmt_lambda_element(session, stmt, None, end_time),
start_time_ts if include_start_time_state else None,
entity_ids, entity_ids,
entity_id_to_metadata_id, entity_id_to_metadata_id,
minimal_response, minimal_response,
@@ -305,6 +301,7 @@ def get_full_significant_states_with_session(
def _state_changed_during_period_stmt( def _state_changed_during_period_stmt(
start_time_ts: float, start_time_ts: float,
start_time_ts_str: str,
end_time_ts: float | None, end_time_ts: float | None,
single_metadata_id: int, single_metadata_id: int,
no_attributes: bool, no_attributes: bool,
@@ -342,7 +339,11 @@ def _state_changed_during_period_stmt(
union_all( union_all(
_select_from_subquery( _select_from_subquery(
_get_single_entity_start_time_stmt( _get_single_entity_start_time_stmt(
start_time_ts, single_metadata_id, no_attributes, False start_time_ts,
start_time_ts_str,
single_metadata_id,
no_attributes,
False,
).subquery(), ).subquery(),
no_attributes, no_attributes,
False, False,
@@ -387,10 +388,12 @@ def state_changes_during_period(
): ):
include_start_time_state = False include_start_time_state = False
start_time_ts = dt_util.utc_to_timestamp(start_time) start_time_ts = dt_util.utc_to_timestamp(start_time)
start_time_ts_str = str(start_time_ts)
end_time_ts = datetime_to_timestamp_or_none(end_time) end_time_ts = datetime_to_timestamp_or_none(end_time)
stmt = lambda_stmt( stmt = lambda_stmt(
lambda: _state_changed_during_period_stmt( lambda: _state_changed_during_period_stmt(
start_time_ts, start_time_ts,
start_time_ts_str,
end_time_ts, end_time_ts,
single_metadata_id, single_metadata_id,
no_attributes, no_attributes,
@@ -411,6 +414,7 @@ def state_changes_during_period(
MutableMapping[str, list[State]], MutableMapping[str, list[State]],
_sorted_states_to_dict( _sorted_states_to_dict(
execute_stmt_lambda_element(session, stmt, None, end_time), execute_stmt_lambda_element(session, stmt, None, end_time),
start_time_ts if include_start_time_state else None,
entity_ids, entity_ids,
entity_id_to_metadata_id, entity_id_to_metadata_id,
), ),
@@ -487,6 +491,7 @@ def get_last_state_changes(
MutableMapping[str, list[State]], MutableMapping[str, list[State]],
_sorted_states_to_dict( _sorted_states_to_dict(
reversed(states), reversed(states),
None,
entity_ids, entity_ids,
entity_id_to_metadata_id, entity_id_to_metadata_id,
), ),
@@ -496,6 +501,7 @@ def get_last_state_changes(
def _get_start_time_state_for_entities_stmt( def _get_start_time_state_for_entities_stmt(
run_start_ts: float, run_start_ts: float,
epoch_time: float, epoch_time: float,
epoch_time_string: str,
metadata_ids: list[int], metadata_ids: list[int],
no_attributes: bool, no_attributes: bool,
include_last_changed: bool, include_last_changed: bool,
@@ -503,7 +509,9 @@ def _get_start_time_state_for_entities_stmt(
"""Baked query to get states for specific entities.""" """Baked query to get states for specific entities."""
# We got an include-list of entities, accelerate the query by filtering already # We got an include-list of entities, accelerate the query by filtering already
# in the inner query. # in the inner query.
stmt = _stmt_and_join_attributes(no_attributes, include_last_changed).join( stmt = _stmt_and_join_attributes_for_epoch_time(
no_attributes, include_last_changed, epoch_time_string
).join(
( (
most_recent_states_for_entities_by_date := ( most_recent_states_for_entities_by_date := (
select( select(
@@ -552,6 +560,7 @@ def _get_run_start_ts_for_utc_point_in_time(
def _get_start_time_state_stmt( def _get_start_time_state_stmt(
run_start_ts: float, run_start_ts: float,
epoch_time: float, epoch_time: float,
epoch_time_str: str,
single_metadata_id: int | None, single_metadata_id: int | None,
metadata_ids: list[int], metadata_ids: list[int],
no_attributes: bool, no_attributes: bool,
@@ -562,22 +571,37 @@ def _get_start_time_state_stmt(
# Use an entirely different (and extremely fast) query if we only # Use an entirely different (and extremely fast) query if we only
# have a single entity id # have a single entity id
return _get_single_entity_start_time_stmt( return _get_single_entity_start_time_stmt(
epoch_time, single_metadata_id, no_attributes, include_last_changed epoch_time,
epoch_time_str,
single_metadata_id,
no_attributes,
include_last_changed,
) )
# We have more than one entity to look at so we need to do a query on states # We have more than one entity to look at so we need to do a query on states
# since the last recorder run started. # since the last recorder run started.
return _get_start_time_state_for_entities_stmt( return _get_start_time_state_for_entities_stmt(
run_start_ts, epoch_time, metadata_ids, no_attributes, include_last_changed run_start_ts,
epoch_time,
epoch_time_str,
metadata_ids,
no_attributes,
include_last_changed,
) )
def _get_single_entity_start_time_stmt( def _get_single_entity_start_time_stmt(
epoch_time: float, metadata_id: int, no_attributes: bool, include_last_changed: bool epoch_time: float,
epoch_time_str: str,
metadata_id: int,
no_attributes: bool,
include_last_changed: bool,
) -> Select: ) -> Select:
# Use an entirely different (and extremely fast) query if we only # Use an entirely different (and extremely fast) query if we only
# have a single entity id # have a single entity id
stmt = ( stmt = (
_stmt_and_join_attributes(no_attributes, include_last_changed) _stmt_and_join_attributes_for_epoch_time(
no_attributes, include_last_changed, epoch_time_str
)
.filter( .filter(
States.last_updated_ts < epoch_time, States.last_updated_ts < epoch_time,
States.metadata_id == metadata_id, States.metadata_id == metadata_id,
@@ -594,6 +618,7 @@ def _get_single_entity_start_time_stmt(
def _sorted_states_to_dict( def _sorted_states_to_dict(
states: Iterable[Row], states: Iterable[Row],
start_time_ts: float | None,
entity_ids: list[str], entity_ids: list[str],
entity_id_to_metadata_id: dict[str, int | None], entity_id_to_metadata_id: dict[str, int | None],
minimal_response: bool = False, minimal_response: bool = False,
@@ -611,7 +636,9 @@ def _sorted_states_to_dict(
axis correctly. axis correctly.
""" """
field_map = _FIELD_MAP field_map = _FIELD_MAP
state_class: Callable[[Row, dict[str, dict[str, Any]]], State | dict[str, Any]] state_class: Callable[
[Row, dict[str, dict[str, Any]], float | None], State | dict[str, Any]
]
if compressed_state_format: if compressed_state_format:
state_class = row_to_compressed_state state_class = row_to_compressed_state
attr_time = COMPRESSED_STATE_LAST_UPDATED attr_time = COMPRESSED_STATE_LAST_UPDATED
@@ -655,7 +682,7 @@ def _sorted_states_to_dict(
or split_entity_id(entity_id)[0] in NEED_ATTRIBUTE_DOMAINS or split_entity_id(entity_id)[0] in NEED_ATTRIBUTE_DOMAINS
): ):
ent_results.extend( ent_results.extend(
state_class(db_state, attr_cache, entity_id=entity_id) # type: ignore[call-arg] state_class(db_state, attr_cache, start_time_ts, entity_id=entity_id) # type: ignore[call-arg]
for db_state in group for db_state in group
) )
continue continue
@@ -669,7 +696,7 @@ def _sorted_states_to_dict(
continue continue
prev_state = first_state.state prev_state = first_state.state
ent_results.append( ent_results.append(
state_class(first_state, attr_cache, entity_id=entity_id) # type: ignore[call-arg] state_class(first_state, attr_cache, start_time_ts, entity_id=entity_id) # type: ignore[call-arg]
) )
state_idx = field_map["state"] state_idx = field_map["state"]
@@ -687,7 +714,7 @@ def _sorted_states_to_dict(
ent_results.extend( ent_results.extend(
{ {
attr_state: (prev_state := state), attr_state: (prev_state := state),
attr_time: row[last_updated_ts_idx], attr_time: row[last_updated_ts_idx] or start_time_ts,
} }
for row in group for row in group
if (state := row[state_idx]) != prev_state if (state := row[state_idx]) != prev_state
@@ -699,7 +726,9 @@ def _sorted_states_to_dict(
ent_results.extend( ent_results.extend(
{ {
attr_state: (prev_state := state), # noqa: F841 attr_state: (prev_state := state), # noqa: F841
attr_time: _utc_from_timestamp(row[last_updated_ts_idx]).isoformat(), attr_time: _utc_from_timestamp(
row[last_updated_ts_idx] or start_time_ts
).isoformat(),
} }
for row in group for row in group
if (state := row[state_idx]) != prev_state if (state := row[state_idx]) != prev_state

View File

@@ -51,6 +51,7 @@ class LazyState(State):
self, self,
row: Row, row: Row,
attr_cache: dict[str, dict[str, Any]], attr_cache: dict[str, dict[str, Any]],
start_time_ts: float | None,
entity_id: str | None = None, entity_id: str | None = None,
) -> None: ) -> None:
"""Init the lazy state.""" """Init the lazy state."""
@@ -58,7 +59,7 @@ class LazyState(State):
self.entity_id = entity_id or self._row.entity_id self.entity_id = entity_id or self._row.entity_id
self.state = self._row.state or "" self.state = self._row.state or ""
self._attributes: dict[str, Any] | None = None self._attributes: dict[str, Any] | None = None
self._last_updated_ts: float | None = self._row.last_updated_ts self._last_updated_ts: float | None = self._row.last_updated_ts or start_time_ts
self._last_changed_ts: float | None = ( self._last_changed_ts: float | None = (
getattr(self._row, "last_changed_ts", None) or self._last_updated_ts getattr(self._row, "last_changed_ts", None) or self._last_updated_ts
) )
@@ -135,6 +136,7 @@ class LazyState(State):
def row_to_compressed_state( def row_to_compressed_state(
row: Row, row: Row,
attr_cache: dict[str, dict[str, Any]], attr_cache: dict[str, dict[str, Any]],
start_time_ts: float | None,
entity_id: str | None = None, entity_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Convert a database row to a compressed state schema 31 and later.""" """Convert a database row to a compressed state schema 31 and later."""
@@ -142,7 +144,7 @@ def row_to_compressed_state(
COMPRESSED_STATE_STATE: row.state, COMPRESSED_STATE_STATE: row.state,
COMPRESSED_STATE_ATTRIBUTES: decode_attributes_from_row(row, attr_cache), COMPRESSED_STATE_ATTRIBUTES: decode_attributes_from_row(row, attr_cache),
} }
row_last_updated_ts: float = row.last_updated_ts row_last_updated_ts: float = row.last_updated_ts or start_time_ts
comp_state[COMPRESSED_STATE_LAST_UPDATED] = row_last_updated_ts comp_state[COMPRESSED_STATE_LAST_UPDATED] = row_last_updated_ts
if ( if (
(row_last_changed_ts := getattr(row, "last_changed_ts", None)) (row_last_changed_ts := getattr(row, "last_changed_ts", None))