mirror of
https://github.com/home-assistant/core.git
synced 2025-08-31 10:21:30 +02:00
reduce overhead
This commit is contained in:
@@ -47,7 +47,7 @@ async def ws_common_control(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Handle usage prediction common control WebSocket API."""
|
"""Handle usage prediction common control WebSocket API."""
|
||||||
result = await get_cached_common_control(hass, connection.user.id)
|
result = await get_cached_common_control(hass, connection.user.id)
|
||||||
time_category = common_control.time_category(dt_util.now())
|
time_category = common_control.time_category(dt_util.now().hour)
|
||||||
connection.send_result(
|
connection.send_result(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
{
|
{
|
||||||
|
@@ -2,11 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import Counter, OrderedDict
|
from collections import Counter
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import json
|
from functools import cache
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -17,6 +18,7 @@ from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
|
|||||||
from homeassistant.components.recorder.util import session_scope
|
from homeassistant.components.recorder.util import session_scope
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
from homeassistant.util.json import json_loads_object
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -26,9 +28,9 @@ TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]
|
|||||||
RESULTS_TO_SHOW = 8
|
RESULTS_TO_SHOW = 8
|
||||||
|
|
||||||
|
|
||||||
def time_category(dt: datetime) -> str:
|
@cache
|
||||||
"""Determine the time category for a given datetime."""
|
def time_category(hour: int) -> str:
|
||||||
hour = dt.hour
|
"""Determine the time category for a given hour."""
|
||||||
if 6 <= hour < 12:
|
if 6 <= hour < 12:
|
||||||
return "morning"
|
return "morning"
|
||||||
if 12 <= hour < 18:
|
if 12 <= hour < 18:
|
||||||
@@ -62,13 +64,13 @@ async def async_predict_common_control(
|
|||||||
def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]:
|
def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]:
|
||||||
"""Fetch and process service call events from the database."""
|
"""Fetch and process service call events from the database."""
|
||||||
# Prepare a dictionary to track results
|
# Prepare a dictionary to track results
|
||||||
results: OrderedDict[str, Counter[str]] = OrderedDict(
|
results: dict[str, Counter[str]] = {
|
||||||
(time_cat, Counter()) for time_cat in TIME_CATEGORIES
|
time_cat: Counter() for time_cat in TIME_CATEGORIES
|
||||||
)
|
}
|
||||||
|
|
||||||
# Keep track of contexts that we processed so that we will only process
|
# Keep track of contexts that we processed so that we will only process
|
||||||
# the first service call in a context, and not subsequent calls.
|
# the first service call in a context, and not subsequent calls.
|
||||||
context_processed = set()
|
context_processed: set[bytes] = set()
|
||||||
|
|
||||||
# Build the query to get call_service events
|
# Build the query to get call_service events
|
||||||
# First, get the event_type_id for 'call_service'
|
# First, get the event_type_id for 'call_service'
|
||||||
@@ -91,7 +93,6 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
|||||||
query = (
|
query = (
|
||||||
select(
|
select(
|
||||||
Events.context_id_bin,
|
Events.context_id_bin,
|
||||||
Events.context_user_id_bin,
|
|
||||||
Events.time_fired_ts,
|
Events.time_fired_ts,
|
||||||
EventData.shared_data,
|
EventData.shared_data,
|
||||||
)
|
)
|
||||||
@@ -104,11 +105,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Execute the query
|
# Execute the query
|
||||||
for row in session.execute(query):
|
context_id: bytes
|
||||||
context_id = row.context_id_bin
|
time_fired_ts: float
|
||||||
shared_data = row.shared_data
|
shared_data: str | None
|
||||||
time_fired_ts = row.time_fired_ts
|
local_time_zone = dt_util.get_default_time_zone()
|
||||||
|
for context_id, time_fired_ts, shared_data in (
|
||||||
|
session.connection().execute(query).all()
|
||||||
|
):
|
||||||
# Skip if we have already processed an event that was part of this context
|
# Skip if we have already processed an event that was part of this context
|
||||||
if context_id in context_processed:
|
if context_id in context_processed:
|
||||||
continue
|
continue
|
||||||
@@ -118,7 +121,7 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
event_data = json.loads(shared_data)
|
event_data = json_loads_object(shared_data)
|
||||||
except (ValueError, TypeError) as err:
|
except (ValueError, TypeError) as err:
|
||||||
_LOGGER.debug("Failed to parse event data: %s", err)
|
_LOGGER.debug("Failed to parse event data: %s", err)
|
||||||
continue
|
continue
|
||||||
@@ -127,12 +130,13 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
|||||||
if not event_data:
|
if not event_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
service_data = event_data.get("service_data")
|
service_data = cast(dict[str, Any] | None, event_data.get("service_data"))
|
||||||
|
|
||||||
# No service data found, skipping
|
# No service data found, skipping
|
||||||
if not service_data:
|
if not service_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
entity_ids: str | list[str] | None
|
||||||
if (target := service_data.get("target")) and (
|
if (target := service_data.get("target")) and (
|
||||||
target_entity_ids := target.get("entity_id")
|
target_entity_ids := target.get("entity_id")
|
||||||
):
|
):
|
||||||
@@ -152,25 +156,20 @@ def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[st
|
|||||||
|
|
||||||
# Convert timestamp to datetime and determine time category
|
# Convert timestamp to datetime and determine time category
|
||||||
if time_fired_ts:
|
if time_fired_ts:
|
||||||
time_fired = dt_util.utc_from_timestamp(time_fired_ts)
|
|
||||||
# Convert to local time for time category determination
|
# Convert to local time for time category determination
|
||||||
local_time = dt_util.as_local(time_fired)
|
period = time_category(
|
||||||
period = time_category(local_time)
|
datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
|
||||||
|
)
|
||||||
|
|
||||||
# Count entity usage
|
# Count entity usage
|
||||||
for entity_id in entity_ids:
|
for entity_id in entity_ids:
|
||||||
results[period][entity_id] += 1
|
results[period][entity_id] += 1
|
||||||
|
|
||||||
# Convert results to lists of top entities
|
# Convert results to lists of top entities
|
||||||
final_results = {}
|
return {
|
||||||
|
period: [ent_id for (ent_id, _) in period_results.most_common(RESULTS_TO_SHOW)]
|
||||||
for period, period_results in results.items():
|
for period, period_results in results.items()
|
||||||
entities = [
|
}
|
||||||
ent_id for (ent_id, _count) in period_results.most_common(RESULTS_TO_SHOW)
|
|
||||||
]
|
|
||||||
final_results[period] = entities
|
|
||||||
|
|
||||||
return final_results
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_with_session(
|
def _fetch_with_session(
|
||||||
|
Reference in New Issue
Block a user