diff --git a/homeassistant/components/recorder/history/modern.py b/homeassistant/components/recorder/history/modern.py index 513f1870cc0..8452eac188b 100644 --- a/homeassistant/components/recorder/history/modern.py +++ b/homeassistant/components/recorder/history/modern.py @@ -16,11 +16,11 @@ from sqlalchemy import ( and_, func, lambda_stmt, + literal, select, union_all, ) from sqlalchemy.engine.row import Row -from sqlalchemy.orm.properties import MappedColumn from sqlalchemy.orm.session import Session from homeassistant.const import COMPRESSED_STATE_LAST_UPDATED, COMPRESSED_STATE_STATE @@ -45,28 +45,13 @@ from .const import ( STATE_KEY, ) -_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED = ( - States.metadata_id, - States.state, - States.last_updated_ts, -) -_QUERY_STATE_NO_ATTR = ( - *_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED, - States.last_changed_ts, -) -_QUERY_ATTRIBUTES = ( - # Remove States.attributes once all attributes are in StateAttributes.shared_attrs - States.attributes, - StateAttributes.shared_attrs, -) -_QUERY_STATES = (*_QUERY_STATE_NO_ATTR, *_QUERY_ATTRIBUTES) -_QUERY_STATES_NO_LAST_CHANGED = ( - *_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED, - *_QUERY_ATTRIBUTES, -) _FIELD_MAP = { - cast(MappedColumn, field).name: idx - for idx, field in enumerate(_QUERY_STATE_NO_ATTR) + "metadata_id": 0, + "state": 1, + "last_updated_ts": 2, + "last_changed_ts": 3, + "attributes": 4, + "shared_attrs": 5, } @@ -74,17 +59,25 @@ def _stmt_and_join_attributes( no_attributes: bool, include_last_changed: bool ) -> Select: """Return the statement and if StateAttributes should be joined.""" - # If no_attributes was requested we do the query - # without the attributes fields and do not join the - # state_attributes table - if no_attributes: - if include_last_changed: - return select(*_QUERY_STATE_NO_ATTR) - return select(*_QUERY_STATE_NO_ATTR_NO_LAST_CHANGED) - + _select = select(States.metadata_id, States.state, States.last_updated_ts) if include_last_changed: - return select(*_QUERY_STATES) - return select(*_QUERY_STATES_NO_LAST_CHANGED) + _select = _select.add_columns(States.last_changed_ts) + if not no_attributes: + _select = _select.add_columns(States.attributes, StateAttributes.shared_attrs) + return _select + + +def _stmt_and_join_attributes_for_epoch_time( + no_attributes: bool, include_last_changed: bool, epoch_time: float +) -> Select: + """Return the statement and if StateAttributes should be joined.""" + _select = select(States.metadata_id, States.state) + _select = _select.add_columns(literal(value=None).label("last_updated_ts")) + if include_last_changed: + _select = _select.add_columns(literal(value=None).label("last_changed_ts")) + if not no_attributes: + _select = _select.add_columns(States.attributes, StateAttributes.shared_attrs) + return _select def _select_from_subquery( @@ -167,12 +160,14 @@ def _significant_states_stmt( stmt = stmt.order_by(States.metadata_id, States.last_updated_ts) if not include_start_time_state or not run_start_ts: return stmt + start_time_ts_str = str(start_time_ts) return _select_from_subquery( union_all( _select_from_subquery( _get_start_time_state_stmt( run_start_ts, start_time_ts, + start_time_ts_str, single_metadata_id, metadata_ids, no_attributes, @@ -263,6 +258,7 @@ def get_significant_states_with_session( ) return _sorted_states_to_dict( execute_stmt_lambda_element(session, stmt, None, end_time), + start_time_ts if include_start_time_state else None, entity_ids, entity_id_to_metadata_id, minimal_response, @@ -305,6 +301,7 @@ def get_full_significant_states_with_session( def _state_changed_during_period_stmt( start_time_ts: float, + start_time_ts_str: str, end_time_ts: float | None, single_metadata_id: int, no_attributes: bool, @@ -342,7 +339,11 @@ def _state_changed_during_period_stmt( union_all( _select_from_subquery( _get_single_entity_start_time_stmt( - start_time_ts, single_metadata_id, no_attributes, False + start_time_ts, + start_time_ts_str, + single_metadata_id, + no_attributes, + False, ).subquery(), no_attributes, False, @@ -387,10 +388,12 @@ def state_changes_during_period( ): include_start_time_state = False start_time_ts = dt_util.utc_to_timestamp(start_time) + start_time_ts_str = str(start_time_ts) end_time_ts = datetime_to_timestamp_or_none(end_time) stmt = lambda_stmt( lambda: _state_changed_during_period_stmt( start_time_ts, + start_time_ts_str, end_time_ts, single_metadata_id, no_attributes, @@ -411,6 +414,7 @@ def state_changes_during_period( MutableMapping[str, list[State]], _sorted_states_to_dict( execute_stmt_lambda_element(session, stmt, None, end_time), + start_time_ts if include_start_time_state else None, entity_ids, entity_id_to_metadata_id, ), @@ -487,6 +491,7 @@ def get_last_state_changes( MutableMapping[str, list[State]], _sorted_states_to_dict( reversed(states), + None, entity_ids, entity_id_to_metadata_id, ), @@ -496,6 +501,7 @@ def get_last_state_changes( def _get_start_time_state_for_entities_stmt( run_start_ts: float, epoch_time: float, + epoch_time_string: str, metadata_ids: list[int], no_attributes: bool, include_last_changed: bool, @@ -503,7 +509,9 @@ def _get_start_time_state_for_entities_stmt( """Baked query to get states for specific entities.""" # We got an include-list of entities, accelerate the query by filtering already # in the inner query. - stmt = _stmt_and_join_attributes(no_attributes, include_last_changed).join( + stmt = _stmt_and_join_attributes_for_epoch_time( + no_attributes, include_last_changed, epoch_time_string + ).join( ( most_recent_states_for_entities_by_date := ( select( @@ -552,6 +560,7 @@ def _get_run_start_ts_for_utc_point_in_time( def _get_start_time_state_stmt( run_start_ts: float, epoch_time: float, + epoch_time_str: str, single_metadata_id: int | None, metadata_ids: list[int], no_attributes: bool, @@ -562,22 +571,37 @@ def _get_start_time_state_stmt( # Use an entirely different (and extremely fast) query if we only # have a single entity id return _get_single_entity_start_time_stmt( - epoch_time, single_metadata_id, no_attributes, include_last_changed + epoch_time, + epoch_time_str, + single_metadata_id, + no_attributes, + include_last_changed, ) # We have more than one entity to look at so we need to do a query on states # since the last recorder run started. return _get_start_time_state_for_entities_stmt( - run_start_ts, epoch_time, metadata_ids, no_attributes, include_last_changed + run_start_ts, + epoch_time, + epoch_time_str, + metadata_ids, + no_attributes, + include_last_changed, ) def _get_single_entity_start_time_stmt( - epoch_time: float, metadata_id: int, no_attributes: bool, include_last_changed: bool + epoch_time: float, + epoch_time_str: str, + metadata_id: int, + no_attributes: bool, + include_last_changed: bool, ) -> Select: # Use an entirely different (and extremely fast) query if we only # have a single entity id stmt = ( - _stmt_and_join_attributes(no_attributes, include_last_changed) + _stmt_and_join_attributes_for_epoch_time( + no_attributes, include_last_changed, epoch_time_str + ) .filter( States.last_updated_ts < epoch_time, States.metadata_id == metadata_id, @@ -594,6 +618,7 @@ def _get_single_entity_start_time_stmt( def _sorted_states_to_dict( states: Iterable[Row], + start_time_ts: float | None, entity_ids: list[str], entity_id_to_metadata_id: dict[str, int | None], minimal_response: bool = False, @@ -611,7 +636,9 @@ def _sorted_states_to_dict( axis correctly. """ field_map = _FIELD_MAP - state_class: Callable[[Row, dict[str, dict[str, Any]]], State | dict[str, Any]] + state_class: Callable[ + [Row, dict[str, dict[str, Any]], float | None], State | dict[str, Any] + ] if compressed_state_format: state_class = row_to_compressed_state attr_time = COMPRESSED_STATE_LAST_UPDATED @@ -655,7 +682,7 @@ def _sorted_states_to_dict( or split_entity_id(entity_id)[0] in NEED_ATTRIBUTE_DOMAINS ): ent_results.extend( - state_class(db_state, attr_cache, entity_id=entity_id) # type: ignore[call-arg] + state_class(db_state, attr_cache, start_time_ts, entity_id=entity_id) # type: ignore[call-arg] for db_state in group ) continue @@ -669,7 +696,7 @@ def _sorted_states_to_dict( continue prev_state = first_state.state ent_results.append( - state_class(first_state, attr_cache, entity_id=entity_id) # type: ignore[call-arg] + state_class(first_state, attr_cache, start_time_ts, entity_id=entity_id) # type: ignore[call-arg] ) state_idx = field_map["state"] @@ -687,7 +714,7 @@ def _sorted_states_to_dict( ent_results.extend( { attr_state: (prev_state := state), - attr_time: row[last_updated_ts_idx], + attr_time: row[last_updated_ts_idx] or start_time_ts, } for row in group if (state := row[state_idx]) != prev_state @@ -699,7 +726,9 @@ def _sorted_states_to_dict( ent_results.extend( { attr_state: (prev_state := state), # noqa: F841 - attr_time: _utc_from_timestamp(row[last_updated_ts_idx]).isoformat(), + attr_time: _utc_from_timestamp( + row[last_updated_ts_idx] or start_time_ts + ).isoformat(), } for row in group if (state := row[state_idx]) != prev_state diff --git a/homeassistant/components/recorder/models/state.py b/homeassistant/components/recorder/models/state.py index c443525701a..14c3444f641 100644 --- a/homeassistant/components/recorder/models/state.py +++ b/homeassistant/components/recorder/models/state.py @@ -51,6 +51,7 @@ class LazyState(State): self, row: Row, attr_cache: dict[str, dict[str, Any]], + start_time_ts: float | None, entity_id: str | None = None, ) -> None: """Init the lazy state.""" @@ -58,7 +59,7 @@ class LazyState(State): self.entity_id = entity_id or self._row.entity_id self.state = self._row.state or "" self._attributes: dict[str, Any] | None = None - self._last_updated_ts: float | None = self._row.last_updated_ts + self._last_updated_ts: float | None = self._row.last_updated_ts or start_time_ts self._last_changed_ts: float | None = ( getattr(self._row, "last_changed_ts", None) or self._last_updated_ts ) @@ -135,6 +136,7 @@ class LazyState(State): def row_to_compressed_state( row: Row, attr_cache: dict[str, dict[str, Any]], + start_time_ts: float | None, entity_id: str | None = None, ) -> dict[str, Any]: """Convert a database row to a compressed state schema 31 and later.""" @@ -142,7 +144,7 @@ def row_to_compressed_state( COMPRESSED_STATE_STATE: row.state, COMPRESSED_STATE_ATTRIBUTES: decode_attributes_from_row(row, attr_cache), } - row_last_updated_ts: float = row.last_updated_ts + row_last_updated_ts: float = row.last_updated_ts or start_time_ts comp_state[COMPRESSED_STATE_LAST_UPDATED] = row_last_updated_ts if ( (row_last_changed_ts := getattr(row, "last_changed_ts", None))