reduce overhead

This commit is contained in:
J. Nick Koston
2025-08-26 18:57:41 +02:00
parent fa41fabd7c
commit 361e46e2c6
2 changed files with 29 additions and 30 deletions

View File

@@ -47,7 +47,7 @@ async def ws_common_control(
) -> None:
"""Handle usage prediction common control WebSocket API."""
result = await get_cached_common_control(hass, connection.user.id)
time_category = common_control.time_category(dt_util.now())
time_category = common_control.time_category(dt_util.now().hour)
connection.send_result(
msg["id"],
{

View File

@@ -2,11 +2,12 @@
from __future__ import annotations
from collections import Counter, OrderedDict
from collections import Counter
from collections.abc import Callable
from datetime import datetime, timedelta
import json
from functools import cache
import logging
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -17,6 +18,7 @@ from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
from homeassistant.util.json import json_loads_object
_LOGGER = logging.getLogger(__name__)
@@ -26,9 +28,9 @@ TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]
RESULTS_TO_SHOW = 8
def time_category(dt: datetime) -> str:
"""Determine the time category for a given datetime."""
hour = dt.hour
@cache
def time_category(hour: int) -> str:
"""Determine the time category for a given hour."""
if 6 <= hour < 12:
return "morning"
if 12 <= hour < 18:
@@ -62,13 +64,13 @@ async def async_predict_common_control(
def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]:
"""Fetch and process service call events from the database."""
# Prepare a dictionary to track results
results: OrderedDict[str, Counter[str]] = OrderedDict(
(time_cat, Counter()) for time_cat in TIME_CATEGORIES
)
results: dict[str, Counter[str]] = {
time_cat: Counter() for time_cat in TIME_CATEGORIES
}
# Keep track of contexts that we processed so that we will only process
# the first service call in a context, and not subsequent calls.
context_processed = set()
context_processed: set[bytes] = set()
# Build the query to get call_service events
# First, get the event_type_id for 'call_service'
@@ -91,7 +93,6 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
query = (
select(
Events.context_id_bin,
Events.context_user_id_bin,
Events.time_fired_ts,
EventData.shared_data,
)
@@ -104,11 +105,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
)
# Execute the query
for row in session.execute(query):
context_id = row.context_id_bin
shared_data = row.shared_data
time_fired_ts = row.time_fired_ts
context_id: bytes
time_fired_ts: float
shared_data: str | None
local_time_zone = dt_util.get_default_time_zone()
for context_id, time_fired_ts, shared_data in (
session.connection().execute(query).all()
):
# Skip if we have already processed an event that was part of this context
if context_id in context_processed:
continue
@@ -118,7 +121,7 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
continue
try:
event_data = json.loads(shared_data)
event_data = json_loads_object(shared_data)
except (ValueError, TypeError) as err:
_LOGGER.debug("Failed to parse event data: %s", err)
continue
@@ -127,12 +130,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
if not event_data:
continue
service_data = event_data.get("service_data")
service_data = cast(dict[str, Any] | None, event_data.get("service_data"))
# No service data found, skipping
if not service_data:
continue
entity_ids: str | list[str] | None
if (target := service_data.get("target")) and (
target_entity_ids := target.get("entity_id")
):
@@ -152,25 +156,20 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
# Convert timestamp to datetime and determine time category
if time_fired_ts:
time_fired = dt_util.utc_from_timestamp(time_fired_ts)
# Convert to local time for time category determination
local_time = dt_util.as_local(time_fired)
period = time_category(local_time)
period = time_category(
datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
)
# Count entity usage
for entity_id in entity_ids:
results[period][entity_id] += 1
# Convert results to lists of top entities
final_results = {}
for period, period_results in results.items():
entities = [
ent_id for (ent_id, _count) in period_results.most_common(RESULTS_TO_SHOW)
]
final_results[period] = entities
return final_results
return {
period: [ent_id for (ent_id, _) in period_results.most_common(RESULTS_TO_SHOW)]
for period, period_results in results.items()
}
def _fetch_with_session(