mirror of
https://github.com/home-assistant/core.git
synced 2025-08-30 09:51:37 +02:00
reduce overhead
This commit is contained in:
@@ -47,7 +47,7 @@ async def ws_common_control(
|
||||
) -> None:
|
||||
"""Handle usage prediction common control WebSocket API."""
|
||||
result = await get_cached_common_control(hass, connection.user.id)
|
||||
time_category = common_control.time_category(dt_util.now())
|
||||
time_category = common_control.time_category(dt_util.now().hour)
|
||||
connection.send_result(
|
||||
msg["id"],
|
||||
{
|
||||
|
@@ -2,11 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter, OrderedDict
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
from functools import cache
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -17,6 +18,7 @@ from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
|
||||
from homeassistant.components.recorder.util import session_scope
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.json import json_loads_object
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,9 +28,9 @@ TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]
|
||||
RESULTS_TO_SHOW = 8
|
||||
|
||||
|
||||
def time_category(dt: datetime) -> str:
|
||||
"""Determine the time category for a given datetime."""
|
||||
hour = dt.hour
|
||||
@cache
|
||||
def time_category(hour: int) -> str:
|
||||
"""Determine the time category for a given hour."""
|
||||
if 6 <= hour < 12:
|
||||
return "morning"
|
||||
if 12 <= hour < 18:
|
||||
@@ -62,13 +64,13 @@ async def async_predict_common_control(
|
||||
def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]:
|
||||
"""Fetch and process service call events from the database."""
|
||||
# Prepare a dictionary to track results
|
||||
results: OrderedDict[str, Counter[str]] = OrderedDict(
|
||||
(time_cat, Counter()) for time_cat in TIME_CATEGORIES
|
||||
)
|
||||
results: dict[str, Counter[str]] = {
|
||||
time_cat: Counter() for time_cat in TIME_CATEGORIES
|
||||
}
|
||||
|
||||
# Keep track of contexts that we processed so that we will only process
|
||||
# the first service call in a context, and not subsequent calls.
|
||||
context_processed = set()
|
||||
context_processed: set[bytes] = set()
|
||||
|
||||
# Build the query to get call_service events
|
||||
# First, get the event_type_id for 'call_service'
|
||||
@@ -91,7 +93,6 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
||||
query = (
|
||||
select(
|
||||
Events.context_id_bin,
|
||||
Events.context_user_id_bin,
|
||||
Events.time_fired_ts,
|
||||
EventData.shared_data,
|
||||
)
|
||||
@@ -104,11 +105,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
for row in session.execute(query):
|
||||
context_id = row.context_id_bin
|
||||
shared_data = row.shared_data
|
||||
time_fired_ts = row.time_fired_ts
|
||||
|
||||
context_id: bytes
|
||||
time_fired_ts: float
|
||||
shared_data: str | None
|
||||
local_time_zone = dt_util.get_default_time_zone()
|
||||
for context_id, time_fired_ts, shared_data in (
|
||||
session.connection().execute(query).all()
|
||||
):
|
||||
# Skip if we have already processed an event that was part of this context
|
||||
if context_id in context_processed:
|
||||
continue
|
||||
@@ -118,7 +121,7 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
||||
continue
|
||||
|
||||
try:
|
||||
event_data = json.loads(shared_data)
|
||||
event_data = json_loads_object(shared_data)
|
||||
except (ValueError, TypeError) as err:
|
||||
_LOGGER.debug("Failed to parse event data: %s", err)
|
||||
continue
|
||||
@@ -127,12 +130,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
||||
if not event_data:
|
||||
continue
|
||||
|
||||
service_data = event_data.get("service_data")
|
||||
service_data = cast(dict[str, Any] | None, event_data.get("service_data"))
|
||||
|
||||
# No service data found, skipping
|
||||
if not service_data:
|
||||
continue
|
||||
|
||||
entity_ids: str | list[str] | None
|
||||
if (target := service_data.get("target")) and (
|
||||
target_entity_ids := target.get("entity_id")
|
||||
):
|
||||
@@ -152,25 +156,20 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
||||
|
||||
# Convert timestamp to datetime and determine time category
|
||||
if time_fired_ts:
|
||||
time_fired = dt_util.utc_from_timestamp(time_fired_ts)
|
||||
# Convert to local time for time category determination
|
||||
local_time = dt_util.as_local(time_fired)
|
||||
period = time_category(local_time)
|
||||
period = time_category(
|
||||
datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
|
||||
)
|
||||
|
||||
# Count entity usage
|
||||
for entity_id in entity_ids:
|
||||
results[period][entity_id] += 1
|
||||
|
||||
# Convert results to lists of top entities
|
||||
final_results = {}
|
||||
|
||||
for period, period_results in results.items():
|
||||
entities = [
|
||||
ent_id for (ent_id, _count) in period_results.most_common(RESULTS_TO_SHOW)
|
||||
]
|
||||
final_results[period] = entities
|
||||
|
||||
return final_results
|
||||
return {
|
||||
period: [ent_id for (ent_id, _) in period_results.most_common(RESULTS_TO_SHOW)]
|
||||
for period, period_results in results.items()
|
||||
}
|
||||
|
||||
|
||||
def _fetch_with_session(
|
||||
|
Reference in New Issue
Block a user