This commit is contained in:
J. Nick Koston
2023-04-10 11:49:32 -10:00
parent 055667d403
commit 79078670dd

View File

@@ -15,6 +15,7 @@ from sqlalchemy import (
Subquery,
and_,
func,
lambda_stmt,
select,
union_all,
)
@@ -76,7 +77,7 @@ _FIELD_MAP = {
def _stmt_and_join_attributes(
no_attributes: bool, include_last_changed: bool = True
no_attributes: bool, include_last_changed: bool
) -> Select:
"""Return the statement and if StateAttributes should be joined."""
# If no_attributes was requested we do the query
@@ -213,7 +214,6 @@ def get_significant_states_with_session(
raise NotImplementedError("Filters are no longer supported")
if not entity_ids:
raise ValueError("entity_ids must be provided")
metadata_ids: list[int] | None = None
entity_id_to_metadata_id: dict[str, int | None] | None = None
metadata_ids_in_significant_domains: list[int] = []
instance = recorder.get_instance(hass)
@@ -221,8 +221,9 @@ def get_significant_states_with_session(
entity_id_to_metadata_id := instance.states_meta_manager.get_many(
entity_ids, session, False
)
) or not (metadata_ids := extract_metadata_ids(entity_id_to_metadata_id)):
) or not (possible_metadata_ids := extract_metadata_ids(entity_id_to_metadata_id)):
return {}
metadata_ids = possible_metadata_ids
if significant_changes_only:
metadata_ids_in_significant_domains = [
metadata_id
@@ -241,15 +242,23 @@ def get_significant_states_with_session(
include_start_time_state = False
start_time_ts = dt_util.utc_to_timestamp(start_time)
end_time_ts = datetime_to_timestamp_or_none(end_time)
stmt = _significant_states_stmt(
start_time_ts,
end_time_ts,
metadata_ids,
metadata_ids_in_significant_domains,
significant_changes_only,
no_attributes,
include_start_time_state,
run_start_ts,
stmt = lambda_stmt(
lambda: _significant_states_stmt(
start_time_ts,
end_time_ts,
metadata_ids,
metadata_ids_in_significant_domains,
significant_changes_only,
no_attributes,
include_start_time_state,
run_start_ts,
),
track_on=[
bool(end_time_ts),
bool(significant_changes_only),
bool(no_attributes),
bool(include_start_time_state),
],
)
states = execute_stmt_lambda_element(session, stmt, None, end_time)
return _sorted_states_to_dict(
@@ -298,6 +307,7 @@ def _state_changed_during_period_stmt(
start_time_ts: float,
end_time_ts: float | None,
metadata_id: int,
metadata_ids: list[int],
no_attributes: bool,
descending: bool,
limit: int | None,
@@ -305,7 +315,7 @@ def _state_changed_during_period_stmt(
run_start_ts: float | None,
) -> Select | CompoundSelect:
stmt = (
_stmt_and_join_attributes(no_attributes, include_last_changed=False)
_stmt_and_join_attributes(no_attributes, False)
.filter(
(
(States.last_changed_ts == States.last_updated_ts)
@@ -333,7 +343,7 @@ def _state_changed_during_period_stmt(
union_all(
_select_from_subquery(
_get_start_time_state_stmt(
run_start_ts, start_time_ts, [metadata_id], no_attributes
run_start_ts, start_time_ts, metadata_ids, no_attributes
).subquery(),
no_attributes,
),
@@ -359,12 +369,14 @@ def state_changes_during_period(
entity_ids = [entity_id.lower()]
with session_scope(hass=hass, read_only=True) as session:
metadata_id: int | None = None
instance = recorder.get_instance(hass)
if not (
metadata_id := instance.states_meta_manager.get(entity_id, session, False)
possible_metadata_id := instance.states_meta_manager.get(
entity_id, session, False
)
):
return {}
metadata_id = possible_metadata_id
entity_id_to_metadata_id: dict[str, int | None] = {entity_id: metadata_id}
run_start_ts: float | None = None
if include_start_time_state:
@@ -377,15 +389,26 @@ def state_changes_during_period(
include_start_time_state = False
start_time_ts = dt_util.utc_to_timestamp(start_time)
end_time_ts = datetime_to_timestamp_or_none(end_time)
stmt = _state_changed_during_period_stmt(
start_time_ts,
end_time_ts,
metadata_id,
no_attributes,
descending,
limit,
include_start_time_state,
run_start_ts,
metadata_ids = [metadata_id]
stmt = lambda_stmt(
lambda: _state_changed_during_period_stmt(
start_time_ts,
end_time_ts,
metadata_id,
metadata_ids,
no_attributes,
descending,
limit,
include_start_time_state,
run_start_ts,
),
track_on=[
bool(end_time_ts),
no_attributes,
descending,
bool(limit),
bool(include_start_time_state),
],
)
states = execute_stmt_lambda_element(session, stmt, None, end_time)
return cast(
@@ -399,7 +422,7 @@ def state_changes_during_period(
def _get_last_state_changes_stmt(number_of_states: int, metadata_id: int) -> Select:
stmt = _stmt_and_join_attributes(False, include_last_changed=False)
stmt = _stmt_and_join_attributes(False, False)
if number_of_states == 1:
stmt = stmt.join(
(
@@ -452,11 +475,17 @@ def get_last_state_changes(
with session_scope(hass=hass, read_only=True) as session:
instance = recorder.get_instance(hass)
if not (
metadata_id := instance.states_meta_manager.get(entity_id, session, False)
possible_metadata_id := instance.states_meta_manager.get(
entity_id, session, False
)
):
return {}
metadata_id = possible_metadata_id
entity_id_to_metadata_id: dict[str, int | None] = {entity_id_lower: metadata_id}
stmt = _get_last_state_changes_stmt(number_of_states, metadata_id)
stmt = lambda_stmt(
lambda: _get_last_state_changes_stmt(number_of_states, metadata_id),
track_on=[number_of_states == 1],
)
states = list(execute_stmt_lambda_element(session, stmt))
return cast(
MutableMapping[str, list[State]],
@@ -475,7 +504,7 @@ def _get_states_for_entities_stmt(
no_attributes: bool,
) -> Select:
"""Baked query to get states for specific entities."""
stmt = _stmt_and_join_attributes(no_attributes, include_last_changed=True)
stmt = _stmt_and_join_attributes(no_attributes, True)
# We got an include-list of entities, accelerate the query by filtering already
# in the inner query.
stmt = stmt.join(
@@ -527,7 +556,7 @@ def _get_start_time_state_stmt(
run_start_ts: float,
epoch_time: float,
metadata_ids: list[int],
no_attributes: bool = False,
no_attributes: bool,
) -> Select:
"""Return the states at a specific point in time."""