mirror of
https://github.com/home-assistant/core.git
synced 2025-08-31 18:31:35 +02:00
Add usage_prediction integration
This commit is contained in:
101
homeassistant/components/usage_prediction/__init__.py
Normal file
101
homeassistant/components/usage_prediction/__init__.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""The usage prediction integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.components import websocket_api
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.storage import Store
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
from . import common_control
|
||||||
|
from .const import DOMAIN
|
||||||
|
|
||||||
|
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||||
|
|
||||||
|
# Storage configuration
|
||||||
|
STORAGE_VERSION = 1
|
||||||
|
STORAGE_KEY_PREFIX = f"{DOMAIN}.common_control"
|
||||||
|
CACHE_DURATION = timedelta(hours=24)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Set up the usage prediction integration.

    Registers the WebSocket command and prepares the per-user cache
    storage container under ``hass.data[DOMAIN]``.
    """
    # Per-user Store instances are created lazily and kept here.
    hass.data[DOMAIN] = {}
    websocket_api.async_register_command(hass, ws_common_control)
    return True
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command(
    {
        "type": f"{DOMAIN}/common_control",
        "user_id": cv.string,
    }
)
@websocket_api.async_response
async def ws_common_control(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Handle usage prediction common control WebSocket API.

    Replies with the entities most used by the calling user during the
    current local time-of-day period.
    """
    # NOTE(review): the schema's "user_id" field is never read here —
    # predictions are always computed for the connected user; confirm intent.
    predictions = await get_cached_common_control(hass, connection.user.id)
    current_period = common_control.time_category(dt_util.now())
    connection.send_result(
        msg["id"],
        {"entities": predictions[current_period]},
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def get_cached_common_control(
    hass: HomeAssistant, user_id: str
) -> dict[str, list[str]]:
    """Get cached common control predictions or fetch new ones.

    Returns cached data if it's less than 24 hours old,
    otherwise fetches new data and caches it.
    """
    # One private Store per user, created lazily and reused across calls.
    storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}"
    domain_data = hass.data[DOMAIN]
    if storage_key not in domain_data:
        domain_data[storage_key] = Store[dict[str, Any]](
            hass, STORAGE_VERSION, storage_key, private=True
        )
    store: Store[dict[str, Any]] = domain_data[storage_key]

    cached_data = await store.async_load()
    now = dt_util.utcnow()

    # Serve from cache while the stored timestamp is still fresh.
    if cached_data is not None:
        cached_time = dt_util.parse_datetime(cached_data.get("timestamp", ""))
        if cached_time is not None and now - cached_time < CACHE_DURATION:
            return cached_data["predictions"]

    # Cache miss or stale: recompute, then persist with a new timestamp.
    predictions = await common_control.async_predict_common_control(hass, user_id)
    await store.async_save(
        {
            "timestamp": now.isoformat(),
            "predictions": predictions,
        }
    )
    return predictions
|
181
homeassistant/components/usage_prediction/common_control.py
Normal file
181
homeassistant/components/usage_prediction/common_control.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
"""Code to generate common control usage patterns."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import Counter, OrderedDict
|
||||||
|
from collections.abc import Callable
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from homeassistant.components.recorder import get_instance
|
||||||
|
from homeassistant.components.recorder.db_schema import EventData, Events, EventTypes
|
||||||
|
from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
|
||||||
|
from homeassistant.components.recorder.util import session_scope
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Time categories for usage patterns
|
||||||
|
TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]
|
||||||
|
|
||||||
|
RESULTS_TO_SHOW = 8
|
||||||
|
|
||||||
|
|
||||||
|
def time_category(dt: datetime) -> str:
    """Return the time bucket for *dt*.

    Buckets (by local hour): morning 06-11, afternoon 12-17,
    evening 18-21, night 22-05.
    """
    hour = dt.hour
    if hour < 6 or hour >= 22:
        return "night"
    if hour < 12:
        return "morning"
    if hour < 18:
        return "afternoon"
    return "evening"
|
||||||
|
|
||||||
|
|
||||||
|
async def async_predict_common_control(
    hass: HomeAssistant, user_id: str
) -> dict[str, list[str]]:
    """Generate a list of commonly used entities for a user.

    Args:
        hass: Home Assistant instance
        user_id: User ID to filter events by.

    Returns:
        Dictionary with time categories as keys and lists of most common
        entity IDs as values.
    """
    # All database access must happen on the recorder's executor thread.
    instance = get_instance(hass)
    return await instance.async_add_executor_job(
        _fetch_with_session, hass, _fetch_and_process_data, user_id
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]:
    """Fetch and process service call events from the database.

    Scans the last 30 days of ``call_service`` events fired by *user_id*,
    counts which entities were targeted per time-of-day category, and
    returns the top ``RESULTS_TO_SHOW`` entity IDs for each category.

    Raises:
        ValueError: if *user_id* is not a valid UUID hex string (only
            reached when at least one call_service event type exists).
    """
    # One counter per time-of-day bucket, in TIME_CATEGORIES order.
    results: OrderedDict[str, Counter[str]] = OrderedDict(
        (time_cat, Counter()) for time_cat in TIME_CATEGORIES
    )

    # Keep track of contexts that we processed so that we will only process
    # the first service call in a context, and not subsequent calls.
    context_processed = set()

    # Resolve the event_type_id for 'call_service'; a brand-new database
    # has none, in which case there is nothing to query.
    event_type_query = select(EventTypes.event_type_id).where(
        EventTypes.event_type == "call_service"
    )
    event_type_result = session.execute(event_type_query).first()

    if not event_type_result:
        _LOGGER.warning("No call_service events found in database")
        return {time_cat: [] for time_cat in TIME_CATEGORIES}

    call_service_type_id = event_type_result[0]
    thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()
    user_id_bytes = uuid_hex_to_bytes_or_none(user_id)
    if not user_id_bytes:
        raise ValueError("Invalid user_id format")

    # Main query: this user's call_service events from the last 30 days,
    # oldest first, joined with their (optional) shared JSON payload.
    query = (
        select(
            Events.context_id_bin,
            Events.context_user_id_bin,
            Events.time_fired_ts,
            EventData.shared_data,
        )
        .select_from(Events)
        .outerjoin(EventData, Events.data_id == EventData.data_id)
        .where(Events.event_type_id == call_service_type_id)
        .where(Events.time_fired_ts >= thirty_days_ago_ts)
        .where(Events.context_user_id_bin == user_id_bytes)
        .order_by(Events.time_fired_ts)
    )

    for row in session.execute(query):
        context_id = row.context_id_bin
        shared_data = row.shared_data
        time_fired_ts = row.time_fired_ts

        # Skip if we have already processed an event that was part of this context
        if context_id in context_processed:
            continue

        # Parse the event data
        if not shared_data:
            continue

        try:
            event_data = json.loads(shared_data)
        except (ValueError, TypeError) as err:
            _LOGGER.debug("Failed to parse event data: %s", err)
            continue

        # Empty event data, skipping
        if not event_data:
            continue

        service_data = event_data.get("service_data")

        # No service data found, skipping
        if not service_data:
            continue

        # Prefer the "target" block; fall back to a top-level
        # "entity_id" inside service_data.
        if (target := service_data.get("target")) and (
            target_entity_ids := target.get("entity_id")
        ):
            entity_ids = target_entity_ids
        else:
            entity_ids = service_data.get("entity_id")

        # No entity IDs found, skip this event
        if entity_ids is None:
            continue

        if not isinstance(entity_ids, list):
            entity_ids = [entity_ids]

        # Mark this context as processed
        context_processed.add(context_id)

        # Count entities only when the fire time is known; keeping the
        # counting inside this guard ensures a row with a missing
        # timestamp can never reuse a stale `period` from a prior row.
        if time_fired_ts:
            # Convert to local time for time category determination.
            local_time = dt_util.as_local(dt_util.utc_from_timestamp(time_fired_ts))
            period = time_category(local_time)
            for entity_id in entity_ids:
                results[period][entity_id] += 1

    # Reduce each counter to its most common entity IDs.
    return {
        period: [ent_id for (ent_id, _count) in counter.most_common(RESULTS_TO_SHOW)]
        for period, counter in results.items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_with_session(
    hass: HomeAssistant, fetch_func: Callable[[Session], dict[str, list[str]]], *args
) -> dict[str, list[str]]:
    """Open a read-only recorder session and run *fetch_func* inside it."""
    with session_scope(hass=hass, read_only=True) as session:
        result = fetch_func(session, *args)
    return result
|
3
homeassistant/components/usage_prediction/const.py
Normal file
3
homeassistant/components/usage_prediction/const.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Constants for the usage prediction integration."""
|
||||||
|
|
||||||
|
DOMAIN = "usage_prediction"
|
10
homeassistant/components/usage_prediction/manifest.json
Normal file
10
homeassistant/components/usage_prediction/manifest.json
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"domain": "usage_prediction",
|
||||||
|
"name": "Usage Prediction",
|
||||||
|
"codeowners": ["@home-assistant/core"],
|
||||||
|
"dependencies": ["http", "recorder"],
|
||||||
|
"documentation": "https://www.home-assistant.io/integrations/usage_prediction",
|
||||||
|
"integration_type": "system",
|
||||||
|
"iot_class": "calculated",
|
||||||
|
"quality_scale": "internal"
|
||||||
|
}
|
20
homeassistant/components/usage_prediction/quality_scale.yaml
Normal file
20
homeassistant/components/usage_prediction/quality_scale.yaml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
rules:
|
||||||
|
# Bronze
|
||||||
|
action-setup: exempt
|
||||||
|
appropriate-polling: exempt
|
||||||
|
brands: todo
|
||||||
|
common-modules: done
|
||||||
|
config-flow-test-coverage: exempt
|
||||||
|
config-flow: exempt
|
||||||
|
dependency-transparency: done
|
||||||
|
docs-actions: exempt
|
||||||
|
docs-high-level-description: todo
|
||||||
|
docs-installation-instructions: exempt
|
||||||
|
docs-removal-instructions: exempt
|
||||||
|
entity-event-setup: exempt
|
||||||
|
entity-unique-id: exempt
|
||||||
|
has-entity-name: exempt
|
||||||
|
runtime-data: exempt
|
||||||
|
test-before-configure: exempt
|
||||||
|
test-before-setup: exempt
|
||||||
|
unique-config-entry: exempt
|
3
homeassistant/components/usage_prediction/strings.json
Normal file
3
homeassistant/components/usage_prediction/strings.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"title": "Usage Prediction"
|
||||||
|
}
|
1
tests/components/usage_prediction/__init__.py
Normal file
1
tests/components/usage_prediction/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for the usage_prediction integration."""
|
354
tests/components/usage_prediction/test_common_control.py
Normal file
354
tests/components/usage_prediction/test_common_control.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
"""Test the common control usage prediction."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components.usage_prediction.common_control import (
|
||||||
|
async_predict_common_control,
|
||||||
|
)
|
||||||
|
from homeassistant.const import EVENT_CALL_SERVICE
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
|
||||||
|
from tests.components.recorder.common import async_wait_recording_done
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
async def test_empty_database(hass: HomeAssistant) -> None:
    """Test function with empty database returns empty results."""
    # No events recorded: every time bucket must come back empty.
    results = await async_predict_common_control(hass, str(uuid.uuid4()))

    assert results == {
        category: [] for category in ("morning", "afternoon", "evening", "night")
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
async def test_invalid_user_id(hass: HomeAssistant) -> None:
    """Test function with invalid user ID returns empty results."""
    # "invalid-user-id" is not a valid UUID; with an empty database the
    # function returns early (no call_service event type yet) before it
    # ever reaches user-id validation, so empty buckets come back.
    results = await async_predict_common_control(hass, "invalid-user-id")

    assert results == {
        "morning": [],
        "afternoon": [],
        "evening": [],
        "night": [],
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
|
||||||
|
async def test_with_service_calls(hass: HomeAssistant) -> None:
|
||||||
|
"""Test function with actual service call events in database."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create service call events at different times of day
|
||||||
|
# Morning events - use separate service calls to get around context deduplication
|
||||||
|
with freeze_time("2023-07-01 07:00:00+00:00"): # Morning
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_on",
|
||||||
|
"service_data": {"entity_id": ["light.living_room", "light.kitchen"]},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Afternoon events
|
||||||
|
with freeze_time("2023-07-01 14:00:00+00:00"): # Afternoon
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "climate",
|
||||||
|
"service": "set_temperature",
|
||||||
|
"service_data": {"entity_id": "climate.thermostat"},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Evening events
|
||||||
|
with freeze_time("2023-07-01 19:00:00+00:00"): # Evening
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_off",
|
||||||
|
"service_data": {"entity_id": "light.bedroom"},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Night events
|
||||||
|
with freeze_time("2023-07-01 23:00:00+00:00"): # Night
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "lock",
|
||||||
|
"service": "lock",
|
||||||
|
"service_data": {"entity_id": "lock.front_door"},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Wait for events to be recorded
|
||||||
|
await async_wait_recording_done(hass)
|
||||||
|
|
||||||
|
# Get predictions - make sure we're still in a reasonable timeframe
|
||||||
|
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
|
||||||
|
results = await async_predict_common_control(hass, user_id)
|
||||||
|
|
||||||
|
# Verify results contain the expected entities in the correct time periods
|
||||||
|
assert results == {
|
||||||
|
"morning": ["climate.thermostat"],
|
||||||
|
"afternoon": ["light.bedroom", "lock.front_door"],
|
||||||
|
"evening": [],
|
||||||
|
"night": ["light.living_room", "light.kitchen"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
|
||||||
|
async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None:
|
||||||
|
"""Test handling of service calls with multiple entity IDs."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
with freeze_time("2023-07-01 10:00:00+00:00"): # Morning
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_on",
|
||||||
|
"service_data": {
|
||||||
|
"entity_id": ["light.living_room", "light.kitchen", "light.hallway"]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
await async_wait_recording_done(hass)
|
||||||
|
|
||||||
|
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
|
||||||
|
results = await async_predict_common_control(hass, user_id)
|
||||||
|
|
||||||
|
# All three lights should be counted (10:00 UTC = 02:00 local = night)
|
||||||
|
assert results["night"] == ["light.living_room", "light.kitchen", "light.hallway"]
|
||||||
|
assert results["morning"] == []
|
||||||
|
assert results["afternoon"] == []
|
||||||
|
assert results["evening"] == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
|
||||||
|
async def test_context_deduplication(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that multiple events with the same context are deduplicated."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
context = Context(user_id=user_id)
|
||||||
|
|
||||||
|
with freeze_time("2023-07-01 10:00:00+00:00"): # Morning
|
||||||
|
# Fire multiple events with the same context
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_on",
|
||||||
|
"service_data": {"entity_id": "light.living_room"},
|
||||||
|
},
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "switch",
|
||||||
|
"service": "turn_on",
|
||||||
|
"service_data": {"entity_id": "switch.coffee_maker"},
|
||||||
|
},
|
||||||
|
context=context, # Same context
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
await async_wait_recording_done(hass)
|
||||||
|
|
||||||
|
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
|
||||||
|
results = await async_predict_common_control(hass, user_id)
|
||||||
|
|
||||||
|
# Only the first event should be processed (10:00 UTC = 02:00 local = night)
|
||||||
|
assert results == {
|
||||||
|
"morning": [],
|
||||||
|
"afternoon": [],
|
||||||
|
"evening": [],
|
||||||
|
"night": ["light.living_room"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
|
||||||
|
async def test_old_events_excluded(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that events older than 30 days are excluded."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create an old event (35 days ago)
|
||||||
|
with freeze_time("2023-05-27 10:00:00+00:00"): # 35 days before July 1st
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_on",
|
||||||
|
"service_data": {"entity_id": "light.old_event"},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Create a recent event (5 days ago)
|
||||||
|
with freeze_time("2023-06-26 10:00:00+00:00"): # 5 days before July 1st
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_on",
|
||||||
|
"service_data": {"entity_id": "light.recent_event"},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
await async_wait_recording_done(hass)
|
||||||
|
|
||||||
|
# Query with current time
|
||||||
|
with freeze_time("2023-07-01 10:00:00+00:00"):
|
||||||
|
results = await async_predict_common_control(hass, user_id)
|
||||||
|
|
||||||
|
# Only recent event should be included (10:00 UTC = 02:00 local = night)
|
||||||
|
assert results == {
|
||||||
|
"morning": [],
|
||||||
|
"afternoon": [],
|
||||||
|
"evening": [],
|
||||||
|
"night": ["light.recent_event"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
|
||||||
|
async def test_entities_limit(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that only top entities are returned per time category."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create more than 5 different entities in morning
|
||||||
|
with freeze_time("2023-07-01 08:00:00+00:00"):
|
||||||
|
# Create entities with different frequencies
|
||||||
|
entities_with_counts = [
|
||||||
|
("light.most_used", 10),
|
||||||
|
("light.second", 8),
|
||||||
|
("light.third", 6),
|
||||||
|
("light.fourth", 4),
|
||||||
|
("light.fifth", 2),
|
||||||
|
("light.sixth", 1),
|
||||||
|
("light.seventh", 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
for entity_id, count in entities_with_counts:
|
||||||
|
for _ in range(count):
|
||||||
|
# Use different context for each call
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "toggle",
|
||||||
|
"service_data": {"entity_id": entity_id},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
await async_wait_recording_done(hass)
|
||||||
|
|
||||||
|
with (
|
||||||
|
freeze_time("2023-07-02 10:00:00+00:00"),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.usage_prediction.common_control.RESULTS_TO_SHOW",
|
||||||
|
5,
|
||||||
|
),
|
||||||
|
): # Next day, so events are recent
|
||||||
|
results = await async_predict_common_control(hass, user_id)
|
||||||
|
|
||||||
|
# Should be the top 5 most used (08:00 UTC = 00:00 local = night)
|
||||||
|
assert results["night"] == [
|
||||||
|
"light.most_used",
|
||||||
|
"light.second",
|
||||||
|
"light.third",
|
||||||
|
"light.fourth",
|
||||||
|
"light.fifth",
|
||||||
|
]
|
||||||
|
assert results["morning"] == []
|
||||||
|
assert results["afternoon"] == []
|
||||||
|
assert results["evening"] == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
|
||||||
|
async def test_different_users_separated(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that events from different users are properly separated."""
|
||||||
|
user_id_1 = str(uuid.uuid4())
|
||||||
|
user_id_2 = str(uuid.uuid4())
|
||||||
|
|
||||||
|
with freeze_time("2023-07-01 10:00:00+00:00"):
|
||||||
|
# User 1 events
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_on",
|
||||||
|
"service_data": {"entity_id": "light.user1_light"},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id_1),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# User 2 events
|
||||||
|
hass.bus.async_fire(
|
||||||
|
EVENT_CALL_SERVICE,
|
||||||
|
{
|
||||||
|
"domain": "light",
|
||||||
|
"service": "turn_on",
|
||||||
|
"service_data": {"entity_id": "light.user2_light"},
|
||||||
|
},
|
||||||
|
context=Context(user_id=user_id_2),
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
await async_wait_recording_done(hass)
|
||||||
|
|
||||||
|
# Get results for each user
|
||||||
|
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
|
||||||
|
results_user1 = await async_predict_common_control(hass, user_id_1)
|
||||||
|
results_user2 = await async_predict_common_control(hass, user_id_2)
|
||||||
|
|
||||||
|
# Each user should only see their own entities (10:00 UTC = 02:00 local = night)
|
||||||
|
assert results_user1 == {
|
||||||
|
"morning": [],
|
||||||
|
"afternoon": [],
|
||||||
|
"evening": [],
|
||||||
|
"night": ["light.user1_light"],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert results_user2 == {
|
||||||
|
"morning": [],
|
||||||
|
"afternoon": [],
|
||||||
|
"evening": [],
|
||||||
|
"night": ["light.user2_light"],
|
||||||
|
}
|
105
tests/components/usage_prediction/test_websocket.py
Normal file
105
tests/components/usage_prediction/test_websocket.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""Test usage_prediction WebSocket API."""
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
from datetime import timedelta
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
|
from tests.common import MockUser
|
||||||
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_predict_common_control() -> Generator[Mock]:
    """Return a mock result for common control."""
    # Canned predictions: one distinct entity per time bucket so tests
    # can tell which bucket the websocket handler picked.
    canned = {
        "morning": ["light.kitchen"],
        "afternoon": ["climate.thermostat"],
        "evening": ["light.bedroom"],
        "night": ["lock.front_door"],
    }
    with patch(
        "homeassistant.components.usage_prediction.common_control.async_predict_common_control",
        return_value=canned,
    ) as mock_predict:
        yield mock_predict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
async def test_common_control(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    hass_admin_user: MockUser,
    mock_predict_common_control: Mock,
) -> None:
    """Test usage_prediction common control WebSocket command."""
    assert await async_setup_component(hass, "usage_prediction", {})

    client = await hass_ws_client(hass)
    await client.send_json({"id": 1, "type": "usage_prediction/common_control"})
    msg = await client.receive_json()

    assert msg["id"] == 1
    assert msg["type"] == "result"
    assert msg["success"] is True
    assert msg["result"] == {
        "entities": [
            "light.kitchen",
        ]
    }

    # The prediction helper ran exactly once, for the calling admin user.
    assert mock_predict_common_control.call_count == 1
    assert mock_predict_common_control.mock_calls[0][1][1] == hass_admin_user.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("recorder_mock")
|
||||||
|
async def test_caching_behavior(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
mock_predict_common_control: Mock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that results are cached for 24 hours."""
|
||||||
|
assert await async_setup_component(hass, "usage_prediction", {})
|
||||||
|
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
now = dt_util.utcnow()
|
||||||
|
|
||||||
|
# First call should fetch from database
|
||||||
|
with freeze_time(now):
|
||||||
|
await client.send_json({"id": 1, "type": "usage_prediction/common_control"})
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"] is True
|
||||||
|
assert msg["result"] == {
|
||||||
|
"entities": [
|
||||||
|
"light.kitchen",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
assert mock_predict_common_control.call_count == 1
|
||||||
|
|
||||||
|
mock_predict_common_control.return_value["morning"].append("light.bla")
|
||||||
|
|
||||||
|
# Second call within 24 hours should use cache
|
||||||
|
with freeze_time(now + timedelta(hours=23)):
|
||||||
|
await client.send_json({"id": 2, "type": "usage_prediction/common_control"})
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"] is True
|
||||||
|
assert msg["result"] == {
|
||||||
|
"entities": [
|
||||||
|
"light.kitchen",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
# Should still be 1 (no new database call)
|
||||||
|
assert mock_predict_common_control.call_count == 1
|
||||||
|
|
||||||
|
# Third call after 24 hours should fetch from database again
|
||||||
|
with freeze_time(now + timedelta(hours=25)):
|
||||||
|
await client.send_json({"id": 3, "type": "usage_prediction/common_control"})
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"] is True
|
||||||
|
assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]}
|
||||||
|
# Should now be 2 (new database call)
|
||||||
|
assert mock_predict_common_control.call_count == 2
|
Reference in New Issue
Block a user