reduce overhead

This commit is contained in:
J. Nick Koston
2025-08-26 18:57:41 +02:00
parent fa41fabd7c
commit 361e46e2c6
2 changed files with 29 additions and 30 deletions

View File

@@ -47,7 +47,7 @@ async def ws_common_control(
) -> None: ) -> None:
"""Handle usage prediction common control WebSocket API.""" """Handle usage prediction common control WebSocket API."""
result = await get_cached_common_control(hass, connection.user.id) result = await get_cached_common_control(hass, connection.user.id)
time_category = common_control.time_category(dt_util.now()) time_category = common_control.time_category(dt_util.now().hour)
connection.send_result( connection.send_result(
msg["id"], msg["id"],
{ {

View File

@@ -2,11 +2,12 @@
from __future__ import annotations from __future__ import annotations
from collections import Counter, OrderedDict from collections import Counter
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json from functools import cache
import logging import logging
from typing import Any, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -17,6 +18,7 @@ from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
from homeassistant.components.recorder.util import session_scope from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.json import json_loads_object
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -26,9 +28,9 @@ TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]
RESULTS_TO_SHOW = 8 RESULTS_TO_SHOW = 8
def time_category(dt: datetime) -> str: @cache
"""Determine the time category for a given datetime.""" def time_category(hour: int) -> str:
hour = dt.hour """Determine the time category for a given hour."""
if 6 <= hour < 12: if 6 <= hour < 12:
return "morning" return "morning"
if 12 <= hour < 18: if 12 <= hour < 18:
@@ -62,13 +64,13 @@ async def async_predict_common_control(
def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]: def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]:
"""Fetch and process service call events from the database.""" """Fetch and process service call events from the database."""
# Prepare a dictionary to track results # Prepare a dictionary to track results
results: OrderedDict[str, Counter[str]] = OrderedDict( results: dict[str, Counter[str]] = {
(time_cat, Counter()) for time_cat in TIME_CATEGORIES time_cat: Counter() for time_cat in TIME_CATEGORIES
) }
# Keep track of contexts that we processed so that we will only process # Keep track of contexts that we processed so that we will only process
# the first service call in a context, and not subsequent calls. # the first service call in a context, and not subsequent calls.
context_processed = set() context_processed: set[bytes] = set()
# Build the query to get call_service events # Build the query to get call_service events
# First, get the event_type_id for 'call_service' # First, get the event_type_id for 'call_service'
@@ -91,7 +93,6 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
query = ( query = (
select( select(
Events.context_id_bin, Events.context_id_bin,
Events.context_user_id_bin,
Events.time_fired_ts, Events.time_fired_ts,
EventData.shared_data, EventData.shared_data,
) )
@@ -104,11 +105,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
) )
# Execute the query # Execute the query
for row in session.execute(query): context_id: bytes
context_id = row.context_id_bin time_fired_ts: float
shared_data = row.shared_data shared_data: str | None
time_fired_ts = row.time_fired_ts local_time_zone = dt_util.get_default_time_zone()
for context_id, time_fired_ts, shared_data in (
session.connection().execute(query).all()
):
# Skip if we have already processed an event that was part of this context # Skip if we have already processed an event that was part of this context
if context_id in context_processed: if context_id in context_processed:
continue continue
@@ -118,7 +121,7 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
continue continue
try: try:
event_data = json.loads(shared_data) event_data = json_loads_object(shared_data)
except (ValueError, TypeError) as err: except (ValueError, TypeError) as err:
_LOGGER.debug("Failed to parse event data: %s", err) _LOGGER.debug("Failed to parse event data: %s", err)
continue continue
@@ -127,12 +130,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
if not event_data: if not event_data:
continue continue
service_data = event_data.get("service_data") service_data = cast(dict[str, Any] | None, event_data.get("service_data"))
# No service data found, skipping # No service data found, skipping
if not service_data: if not service_data:
continue continue
entity_ids: str | list[str] | None
if (target := service_data.get("target")) and ( if (target := service_data.get("target")) and (
target_entity_ids := target.get("entity_id") target_entity_ids := target.get("entity_id")
): ):
@@ -152,25 +156,20 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
# Convert timestamp to datetime and determine time category # Convert timestamp to datetime and determine time category
if time_fired_ts: if time_fired_ts:
time_fired = dt_util.utc_from_timestamp(time_fired_ts)
# Convert to local time for time category determination # Convert to local time for time category determination
local_time = dt_util.as_local(time_fired) period = time_category(
period = time_category(local_time) datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
)
# Count entity usage # Count entity usage
for entity_id in entity_ids: for entity_id in entity_ids:
results[period][entity_id] += 1 results[period][entity_id] += 1
# Convert results to lists of top entities # Convert results to lists of top entities
final_results = {} return {
period: [ent_id for (ent_id, _) in period_results.most_common(RESULTS_TO_SHOW)]
for period, period_results in results.items(): for period, period_results in results.items()
entities = [ }
ent_id for (ent_id, _count) in period_results.most_common(RESULTS_TO_SHOW)
]
final_results[period] = entities
return final_results
def _fetch_with_session( def _fetch_with_session(