diff --git a/homeassistant/components/usage_prediction/__init__.py b/homeassistant/components/usage_prediction/__init__.py index a6688139d4a..c95e73e4351 100644 --- a/homeassistant/components/usage_prediction/__init__.py +++ b/homeassistant/components/usage_prediction/__init__.py @@ -47,7 +47,7 @@ async def ws_common_control( ) -> None: """Handle usage prediction common control WebSocket API.""" result = await get_cached_common_control(hass, connection.user.id) - time_category = common_control.time_category(dt_util.now()) + time_category = common_control.time_category(dt_util.now().hour) connection.send_result( msg["id"], { diff --git a/homeassistant/components/usage_prediction/common_control.py b/homeassistant/components/usage_prediction/common_control.py index b110187cc0a..05025a52afa 100644 --- a/homeassistant/components/usage_prediction/common_control.py +++ b/homeassistant/components/usage_prediction/common_control.py @@ -2,11 +2,12 @@ from __future__ import annotations -from collections import Counter, OrderedDict +from collections import Counter from collections.abc import Callable from datetime import datetime, timedelta -import json +from functools import cache import logging +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -17,6 +18,7 @@ from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none from homeassistant.components.recorder.util import session_scope from homeassistant.core import HomeAssistant from homeassistant.util import dt as dt_util +from homeassistant.util.json import json_loads_object _LOGGER = logging.getLogger(__name__) @@ -26,9 +28,9 @@ TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"] RESULTS_TO_SHOW = 8 -def time_category(dt: datetime) -> str: - """Determine the time category for a given datetime.""" - hour = dt.hour +@cache +def time_category(hour: int) -> str: + """Determine the time category for a given hour.""" if 6 <= hour < 12: return "morning" if 12 <= hour < 18: @@ -62,13 +64,13 @@ async def async_predict_common_control( def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]: """Fetch and process service call events from the database.""" # Prepare a dictionary to track results - results: OrderedDict[str, Counter[str]] = OrderedDict( - (time_cat, Counter()) for time_cat in TIME_CATEGORIES - ) + results: dict[str, Counter[str]] = { + time_cat: Counter() for time_cat in TIME_CATEGORIES + } # Keep track of contexts that we processed so that we will only process # the first service call in a context, and not subsequent calls. - context_processed = set() + context_processed: set[bytes] = set() # Build the query to get call_service events # First, get the event_type_id for 'call_service' @@ -91,7 +93,6 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st query = ( select( Events.context_id_bin, - Events.context_user_id_bin, Events.time_fired_ts, EventData.shared_data, ) @@ -104,11 +105,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st ) # Execute the query - for row in session.execute(query): - context_id = row.context_id_bin - shared_data = row.shared_data - time_fired_ts = row.time_fired_ts - + context_id: bytes + time_fired_ts: float + shared_data: str | None + local_time_zone = dt_util.get_default_time_zone() + for context_id, time_fired_ts, shared_data in ( + session.connection().execute(query).all() + ): # Skip if we have already processed an event that was part of this context if context_id in context_processed: continue @@ -118,7 +121,7 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st continue try: - event_data = json.loads(shared_data) + event_data = json_loads_object(shared_data) except (ValueError, TypeError) as err: _LOGGER.debug("Failed to parse event data: %s", err) continue @@ -127,12 +130,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st if not event_data: continue - service_data = event_data.get("service_data") + service_data = cast(dict[str, Any] | None, event_data.get("service_data")) # No service data found, skipping if not service_data: continue + entity_ids: str | list[str] | None if (target := service_data.get("target")) and ( target_entity_ids := target.get("entity_id") ): @@ -152,25 +156,20 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st # Convert timestamp to datetime and determine time category if time_fired_ts: - time_fired = dt_util.utc_from_timestamp(time_fired_ts) # Convert to local time for time category determination - local_time = dt_util.as_local(time_fired) - period = time_category(local_time) + period = time_category( + datetime.fromtimestamp(time_fired_ts, local_time_zone).hour + ) # Count entity usage for entity_id in entity_ids: results[period][entity_id] += 1 # Convert results to lists of top entities - final_results = {} - - for period, period_results in results.items(): - entities = [ - ent_id for (ent_id, _count) in period_results.most_common(RESULTS_TO_SHOW) - ] - final_results[period] = entities - - return final_results + return { + period: [ent_id for (ent_id, _) in period_results.most_common(RESULTS_TO_SHOW)] + for period, period_results in results.items() + } def _fetch_with_session(