Add usage_prediction integration

This commit is contained in:
Paulus Schoutsen
2025-08-26 17:59:04 +02:00
parent ce523fc91d
commit fa41fabd7c
9 changed files with 778 additions and 0 deletions

View File

@@ -0,0 +1,101 @@
"""The usage prediction integration."""
from __future__ import annotations
from datetime import timedelta
from typing import Any
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.storage import Store
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import dt as dt_util
from . import common_control
from .const import DOMAIN
# Empty schema: this integration takes no YAML configuration.
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
# Storage configuration
STORAGE_VERSION = 1
# Per-user stores are keyed as "<STORAGE_KEY_PREFIX>.<user_id>".
STORAGE_KEY_PREFIX = f"{DOMAIN}.common_control"
# Cached predictions are reused for this long before being recomputed.
CACHE_DURATION = timedelta(hours=24)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the usage prediction integration."""
    # Per-user Store instances are created lazily and kept under DOMAIN.
    hass.data[DOMAIN] = {}
    websocket_api.async_register_command(hass, ws_common_control)
    return True
@websocket_api.websocket_command(
    {
        "type": f"{DOMAIN}/common_control",
    }
)
@websocket_api.async_response
async def ws_common_control(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Handle usage prediction common control WebSocket API.

    Predictions are always computed for the calling connection's own user
    (``connection.user.id``); the previous ``user_id`` schema field was
    never read by this handler and never sent by clients, so it is removed.
    Responds with the predicted entities for the current local time bucket.
    """
    result = await get_cached_common_control(hass, connection.user.id)
    # dt_util.now() is local time; the bucket matches the user's clock.
    time_category = common_control.time_category(dt_util.now())
    connection.send_result(
        msg["id"],
        {
            "entities": result[time_category],
        },
    )
async def get_cached_common_control(
    hass: HomeAssistant, user_id: str
) -> dict[str, list[str]]:
    """Return common control predictions for a user, cached for 24 hours.

    A per-user Store holds the last computed predictions together with a
    timestamp; the stored result is reused while it is younger than
    CACHE_DURATION, otherwise fresh predictions are computed and saved.
    """
    storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}"

    # Lazily create one Store per user, kept in hass.data[DOMAIN].
    store: Store[dict[str, Any]] | None = hass.data[DOMAIN].get(storage_key)
    if store is None:
        store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key, private=True)
        hass.data[DOMAIN][storage_key] = store

    cached = await store.async_load()
    now = dt_util.utcnow()

    if cached is not None:
        # parse_datetime returns None for a missing or empty timestamp.
        written_at = dt_util.parse_datetime(cached.get("timestamp", ""))
        if written_at is not None and now - written_at < CACHE_DURATION:
            return cached["predictions"]

    # Cache miss or expired: compute fresh predictions and persist them.
    predictions = await common_control.async_predict_common_control(hass, user_id)
    await store.async_save({"timestamp": now.isoformat(), "predictions": predictions})
    return predictions

View File

@@ -0,0 +1,181 @@
"""Code to generate common control usage patterns."""
from __future__ import annotations
from collections import Counter, OrderedDict
from collections.abc import Callable
from datetime import datetime, timedelta
import json
import logging
from sqlalchemy import select
from sqlalchemy.orm import Session
from homeassistant.components.recorder import get_instance
from homeassistant.components.recorder.db_schema import EventData, Events, EventTypes
from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
from homeassistant.components.recorder.util import session_scope
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
_LOGGER = logging.getLogger(__name__)
# Time categories for usage patterns
TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]
# Maximum number of entities returned per time category.
RESULTS_TO_SHOW = 8
def time_category(dt: datetime) -> str:
    """Return the time-of-day bucket for *dt*.

    Buckets: morning [06:00, 12:00), afternoon [12:00, 18:00),
    evening [18:00, 22:00), night otherwise.
    """
    hour = dt.hour
    if hour < 6:
        return "night"
    if hour < 12:
        return "morning"
    if hour < 18:
        return "afternoon"
    if hour < 22:
        return "evening"
    return "night"
async def async_predict_common_control(
    hass: HomeAssistant, user_id: str
) -> dict[str, list[str]]:
    """Generate a list of commonly used entities for a user.

    Args:
        hass: Home Assistant instance.
        user_id: User ID to filter events by.

    Returns:
        Mapping of time category to the most common entity IDs.
    """
    # All database work must run in the recorder's executor thread.
    return await get_instance(hass).async_add_executor_job(
        _fetch_with_session, hass, _fetch_and_process_data, user_id
    )
def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]:
    """Fetch and process service call events from the database.

    Scans up to 30 days of ``call_service`` events fired by *user_id*,
    buckets the targeted entity IDs by local time-of-day category, and
    returns the top ``RESULTS_TO_SHOW`` entity IDs per category.

    Raises:
        ValueError: If *user_id* is not a valid UUID hex string. Note this
            check runs only after the ``call_service`` event-type lookup
            succeeds, so on an empty database the function returns empty
            results before ever validating the user ID.
    """
    # Prepare a dictionary to track results
    results: OrderedDict[str, Counter[str]] = OrderedDict(
        (time_cat, Counter()) for time_cat in TIME_CATEGORIES
    )
    # Keep track of contexts that we processed so that we will only process
    # the first service call in a context, and not subsequent calls.
    context_processed = set()
    # Build the query to get call_service events
    # First, get the event_type_id for 'call_service'
    event_type_query = select(EventTypes.event_type_id).where(
        EventTypes.event_type == "call_service"
    )
    event_type_result = session.execute(event_type_query).first()
    if not event_type_result:
        _LOGGER.warning("No call_service events found in database")
        return {time_cat: [] for time_cat in TIME_CATEGORIES}
    call_service_type_id = event_type_result[0]
    # Recorder stores event times as Unix timestamps; compare in that form.
    thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()
    user_id_bytes = uuid_hex_to_bytes_or_none(user_id)
    if not user_id_bytes:
        raise ValueError("Invalid user_id format")
    # Build the main query for events with their data
    query = (
        select(
            Events.context_id_bin,
            Events.context_user_id_bin,
            Events.time_fired_ts,
            EventData.shared_data,
        )
        .select_from(Events)
        .outerjoin(EventData, Events.data_id == EventData.data_id)
        .where(Events.event_type_id == call_service_type_id)
        .where(Events.time_fired_ts >= thirty_days_ago_ts)
        .where(Events.context_user_id_bin == user_id_bytes)
        .order_by(Events.time_fired_ts)
    )
    # Execute the query
    for row in session.execute(query):
        context_id = row.context_id_bin
        shared_data = row.shared_data
        time_fired_ts = row.time_fired_ts
        # Skip if we have already processed an event that was part of this context
        if context_id in context_processed:
            continue
        # Parse the event data
        if not shared_data:
            continue
        try:
            event_data = json.loads(shared_data)
        except (ValueError, TypeError) as err:
            _LOGGER.debug("Failed to parse event data: %s", err)
            continue
        # Empty event data, skipping
        if not event_data:
            continue
        service_data = event_data.get("service_data")
        # No service data found, skipping
        if not service_data:
            continue
        # Prefer entity IDs from the "target" field; fall back to the
        # top-level "entity_id" key inside service_data.
        if (target := service_data.get("target")) and (
            target_entity_ids := target.get("entity_id")
        ):
            entity_ids = target_entity_ids
        else:
            entity_ids = service_data.get("entity_id")
        # No entity IDs found, skip this event
        if entity_ids is None:
            continue
        if not isinstance(entity_ids, list):
            entity_ids = [entity_ids]
        # Mark this context as processed
        # NOTE(review): a context is only marked after an event WITH entity
        # IDs is seen, so earlier events in the same context that lack entity
        # IDs do not suppress later ones — confirm this is intended.
        context_processed.add(context_id)
        # Convert timestamp to datetime and determine time category
        if time_fired_ts:
            time_fired = dt_util.utc_from_timestamp(time_fired_ts)
            # Convert to local time for time category determination
            local_time = dt_util.as_local(time_fired)
            period = time_category(local_time)
            # Count entity usage
            for entity_id in entity_ids:
                results[period][entity_id] += 1
    # Convert results to lists of top entities
    final_results = {}
    for period, period_results in results.items():
        entities = [
            ent_id for (ent_id, _count) in period_results.most_common(RESULTS_TO_SHOW)
        ]
        final_results[period] = entities
    return final_results
def _fetch_with_session(
    hass: HomeAssistant, fetch_func: Callable[[Session], dict[str, list[str]]], *args
) -> dict[str, list[str]]:
    """Execute a fetch function with a database session.

    Opens a read-only recorder session and calls *fetch_func* with the
    session followed by any extra positional *args*. Intended to run in
    the recorder's executor thread, not on the event loop.
    """
    with session_scope(hass=hass, read_only=True) as session:
        return fetch_func(session, *args)

View File

@@ -0,0 +1,3 @@
"""Constants for the usage prediction integration."""
DOMAIN = "usage_prediction"

View File

@@ -0,0 +1,10 @@
{
"domain": "usage_prediction",
"name": "Usage Prediction",
"codeowners": ["@home-assistant/core"],
"dependencies": ["http", "recorder"],
"documentation": "https://www.home-assistant.io/integrations/usage_prediction",
"integration_type": "system",
"iot_class": "calculated",
"quality_scale": "internal"
}

View File

@@ -0,0 +1,20 @@
rules:
# Bronze
action-setup: exempt
appropriate-polling: exempt
brands: todo
common-modules: done
config-flow-test-coverage: exempt
config-flow: exempt
dependency-transparency: done
docs-actions: exempt
docs-high-level-description: todo
docs-installation-instructions: exempt
docs-removal-instructions: exempt
entity-event-setup: exempt
entity-unique-id: exempt
has-entity-name: exempt
runtime-data: exempt
test-before-configure: exempt
test-before-setup: exempt
unique-config-entry: exempt

View File

@@ -0,0 +1,3 @@
{
"title": "Usage Prediction"
}

View File

@@ -0,0 +1 @@
"""Tests for the usage_prediction integration."""

View File

@@ -0,0 +1,354 @@
"""Test the common control usage prediction."""
from __future__ import annotations
from unittest.mock import patch
import uuid
from freezegun import freeze_time
import pytest
from homeassistant.components.usage_prediction.common_control import (
async_predict_common_control,
)
from homeassistant.const import EVENT_CALL_SERVICE
from homeassistant.core import Context, HomeAssistant
from tests.components.recorder.common import async_wait_recording_done
@pytest.mark.usefixtures("recorder_mock")
async def test_empty_database(hass: HomeAssistant) -> None:
"""Test function with empty database returns empty results."""
user_id = str(uuid.uuid4())
# Call the function with empty database
results = await async_predict_common_control(hass, user_id)
# Should return empty lists for all time categories
assert results == {
"morning": [],
"afternoon": [],
"evening": [],
"night": [],
}
@pytest.mark.usefixtures("recorder_mock")
async def test_invalid_user_id(hass: HomeAssistant) -> None:
"""Test function with invalid user ID returns empty results."""
# Invalid user ID format (not a valid UUID)
results = await async_predict_common_control(hass, "invalid-user-id")
# Should return empty lists for all time categories due to invalid user ID
assert results == {
"morning": [],
"afternoon": [],
"evening": [],
"night": [],
}
@pytest.mark.usefixtures("recorder_mock")
async def test_with_service_calls(hass: HomeAssistant) -> None:
"""Test function with actual service call events in database."""
user_id = str(uuid.uuid4())
# Create service call events at different times of day
# Morning events - use separate service calls to get around context deduplication
with freeze_time("2023-07-01 07:00:00+00:00"): # Morning
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "turn_on",
"service_data": {"entity_id": ["light.living_room", "light.kitchen"]},
},
context=Context(user_id=user_id),
)
await hass.async_block_till_done()
# Afternoon events
with freeze_time("2023-07-01 14:00:00+00:00"): # Afternoon
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "climate",
"service": "set_temperature",
"service_data": {"entity_id": "climate.thermostat"},
},
context=Context(user_id=user_id),
)
await hass.async_block_till_done()
# Evening events
with freeze_time("2023-07-01 19:00:00+00:00"): # Evening
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "turn_off",
"service_data": {"entity_id": "light.bedroom"},
},
context=Context(user_id=user_id),
)
await hass.async_block_till_done()
# Night events
with freeze_time("2023-07-01 23:00:00+00:00"): # Night
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "lock",
"service": "lock",
"service_data": {"entity_id": "lock.front_door"},
},
context=Context(user_id=user_id),
)
await hass.async_block_till_done()
# Wait for events to be recorded
await async_wait_recording_done(hass)
# Get predictions - make sure we're still in a reasonable timeframe
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
results = await async_predict_common_control(hass, user_id)
# Verify results contain the expected entities in the correct time periods
assert results == {
"morning": ["climate.thermostat"],
"afternoon": ["light.bedroom", "lock.front_door"],
"evening": [],
"night": ["light.living_room", "light.kitchen"],
}
@pytest.mark.usefixtures("recorder_mock")
async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None:
"""Test handling of service calls with multiple entity IDs."""
user_id = str(uuid.uuid4())
with freeze_time("2023-07-01 10:00:00+00:00"): # Morning
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "turn_on",
"service_data": {
"entity_id": ["light.living_room", "light.kitchen", "light.hallway"]
},
},
context=Context(user_id=user_id),
)
await hass.async_block_till_done()
await async_wait_recording_done(hass)
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
results = await async_predict_common_control(hass, user_id)
# All three lights should be counted (10:00 UTC = 02:00 local = night)
assert results["night"] == ["light.living_room", "light.kitchen", "light.hallway"]
assert results["morning"] == []
assert results["afternoon"] == []
assert results["evening"] == []
@pytest.mark.usefixtures("recorder_mock")
async def test_context_deduplication(hass: HomeAssistant) -> None:
"""Test that multiple events with the same context are deduplicated."""
user_id = str(uuid.uuid4())
context = Context(user_id=user_id)
with freeze_time("2023-07-01 10:00:00+00:00"): # Morning
# Fire multiple events with the same context
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "turn_on",
"service_data": {"entity_id": "light.living_room"},
},
context=context,
)
await hass.async_block_till_done()
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "switch",
"service": "turn_on",
"service_data": {"entity_id": "switch.coffee_maker"},
},
context=context, # Same context
)
await hass.async_block_till_done()
await async_wait_recording_done(hass)
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
results = await async_predict_common_control(hass, user_id)
# Only the first event should be processed (10:00 UTC = 02:00 local = night)
assert results == {
"morning": [],
"afternoon": [],
"evening": [],
"night": ["light.living_room"],
}
@pytest.mark.usefixtures("recorder_mock")
async def test_old_events_excluded(hass: HomeAssistant) -> None:
"""Test that events older than 30 days are excluded."""
user_id = str(uuid.uuid4())
# Create an old event (35 days ago)
with freeze_time("2023-05-27 10:00:00+00:00"): # 35 days before July 1st
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "turn_on",
"service_data": {"entity_id": "light.old_event"},
},
context=Context(user_id=user_id),
)
await hass.async_block_till_done()
# Create a recent event (5 days ago)
with freeze_time("2023-06-26 10:00:00+00:00"): # 5 days before July 1st
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "turn_on",
"service_data": {"entity_id": "light.recent_event"},
},
context=Context(user_id=user_id),
)
await hass.async_block_till_done()
await async_wait_recording_done(hass)
# Query with current time
with freeze_time("2023-07-01 10:00:00+00:00"):
results = await async_predict_common_control(hass, user_id)
# Only recent event should be included (10:00 UTC = 02:00 local = night)
assert results == {
"morning": [],
"afternoon": [],
"evening": [],
"night": ["light.recent_event"],
}
@pytest.mark.usefixtures("recorder_mock")
async def test_entities_limit(hass: HomeAssistant) -> None:
"""Test that only top entities are returned per time category."""
user_id = str(uuid.uuid4())
# Create more than 5 different entities in morning
with freeze_time("2023-07-01 08:00:00+00:00"):
# Create entities with different frequencies
entities_with_counts = [
("light.most_used", 10),
("light.second", 8),
("light.third", 6),
("light.fourth", 4),
("light.fifth", 2),
("light.sixth", 1),
("light.seventh", 1),
]
for entity_id, count in entities_with_counts:
for _ in range(count):
# Use different context for each call
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "toggle",
"service_data": {"entity_id": entity_id},
},
context=Context(user_id=user_id),
)
await hass.async_block_till_done()
await async_wait_recording_done(hass)
with (
freeze_time("2023-07-02 10:00:00+00:00"),
patch(
"homeassistant.components.usage_prediction.common_control.RESULTS_TO_SHOW",
5,
),
): # Next day, so events are recent
results = await async_predict_common_control(hass, user_id)
# Should be the top 5 most used (08:00 UTC = 00:00 local = night)
assert results["night"] == [
"light.most_used",
"light.second",
"light.third",
"light.fourth",
"light.fifth",
]
assert results["morning"] == []
assert results["afternoon"] == []
assert results["evening"] == []
@pytest.mark.usefixtures("recorder_mock")
async def test_different_users_separated(hass: HomeAssistant) -> None:
"""Test that events from different users are properly separated."""
user_id_1 = str(uuid.uuid4())
user_id_2 = str(uuid.uuid4())
with freeze_time("2023-07-01 10:00:00+00:00"):
# User 1 events
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "turn_on",
"service_data": {"entity_id": "light.user1_light"},
},
context=Context(user_id=user_id_1),
)
await hass.async_block_till_done()
# User 2 events
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
"domain": "light",
"service": "turn_on",
"service_data": {"entity_id": "light.user2_light"},
},
context=Context(user_id=user_id_2),
)
await hass.async_block_till_done()
await async_wait_recording_done(hass)
# Get results for each user
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
results_user1 = await async_predict_common_control(hass, user_id_1)
results_user2 = await async_predict_common_control(hass, user_id_2)
# Each user should only see their own entities (10:00 UTC = 02:00 local = night)
assert results_user1 == {
"morning": [],
"afternoon": [],
"evening": [],
"night": ["light.user1_light"],
}
assert results_user2 == {
"morning": [],
"afternoon": [],
"evening": [],
"night": ["light.user2_light"],
}

View File

@@ -0,0 +1,105 @@
"""Test usage_prediction WebSocket API."""
from collections.abc import Generator
from datetime import timedelta
from unittest.mock import Mock, patch
from freezegun import freeze_time
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from tests.common import MockUser
from tests.typing import WebSocketGenerator
@pytest.fixture
def mock_predict_common_control() -> Generator[Mock]:
    """Return a mock result for common control.

    Patches the function where the integration looks it up (the
    common_control module), so the __init__ cache wrapper receives this
    mock's return value instead of hitting the recorder database.
    """
    with patch(
        "homeassistant.components.usage_prediction.common_control.async_predict_common_control",
        return_value={
            "morning": ["light.kitchen"],
            "afternoon": ["climate.thermostat"],
            "evening": ["light.bedroom"],
            "night": ["lock.front_door"],
        },
    ) as mock_predict:
        yield mock_predict
@pytest.mark.usefixtures("recorder_mock")
async def test_common_control(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
hass_admin_user: MockUser,
mock_predict_common_control: Mock,
) -> None:
"""Test usage_prediction common control WebSocket command."""
assert await async_setup_component(hass, "usage_prediction", {})
client = await hass_ws_client(hass)
await client.send_json({"id": 1, "type": "usage_prediction/common_control"})
msg = await client.receive_json()
assert msg["id"] == 1
assert msg["type"] == "result"
assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
assert mock_predict_common_control.call_count == 1
assert mock_predict_common_control.mock_calls[0][1][1] == hass_admin_user.id
@pytest.mark.usefixtures("recorder_mock")
async def test_caching_behavior(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
mock_predict_common_control: Mock,
) -> None:
"""Test that results are cached for 24 hours."""
assert await async_setup_component(hass, "usage_prediction", {})
client = await hass_ws_client(hass)
now = dt_util.utcnow()
# First call should fetch from database
with freeze_time(now):
await client.send_json({"id": 1, "type": "usage_prediction/common_control"})
msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
assert mock_predict_common_control.call_count == 1
mock_predict_common_control.return_value["morning"].append("light.bla")
# Second call within 24 hours should use cache
with freeze_time(now + timedelta(hours=23)):
await client.send_json({"id": 2, "type": "usage_prediction/common_control"})
msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
# Should still be 1 (no new database call)
assert mock_predict_common_control.call_count == 1
# Third call after 24 hours should fetch from database again
with freeze_time(now + timedelta(hours=25)):
await client.send_json({"id": 3, "type": "usage_prediction/common_control"})
msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]}
# Should now be 2 (new database call)
assert mock_predict_common_control.call_count == 2