diff --git a/homeassistant/components/usage_prediction/__init__.py b/homeassistant/components/usage_prediction/__init__.py new file mode 100644 index 00000000000..a6688139d4a --- /dev/null +++ b/homeassistant/components/usage_prediction/__init__.py @@ -0,0 +1,101 @@ +"""The usage prediction integration.""" + +from __future__ import annotations + +from datetime import timedelta +from typing import Any + +from homeassistant.components import websocket_api +from homeassistant.core import HomeAssistant +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.storage import Store +from homeassistant.helpers.typing import ConfigType +from homeassistant.util import dt as dt_util + +from . import common_control +from .const import DOMAIN + +CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) + +# Storage configuration +STORAGE_VERSION = 1 +STORAGE_KEY_PREFIX = f"{DOMAIN}.common_control" +CACHE_DURATION = timedelta(hours=24) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up the usage prediction integration.""" + websocket_api.async_register_command(hass, ws_common_control) + + # Initialize domain data storage + hass.data[DOMAIN] = {} + + return True + + +@websocket_api.websocket_command( + { + "type": f"{DOMAIN}/common_control", + "user_id": cv.string, + } +) +@websocket_api.async_response +async def ws_common_control( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Handle usage prediction common control WebSocket API.""" + result = await get_cached_common_control(hass, connection.user.id) + time_category = common_control.time_category(dt_util.now()) + connection.send_result( + msg["id"], + { + "entities": result[time_category], + }, + ) + + +async def get_cached_common_control( + hass: HomeAssistant, user_id: str +) -> dict[str, list[str]]: + """Get cached common control predictions or fetch new ones. + + Returns cached data if it's less than 24 hours old, + otherwise fetches new data and caches it. + """ + # Create a unique storage key for this user + storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}" + + # Get or create store for this user + if storage_key not in hass.data[DOMAIN]: + hass.data[DOMAIN][storage_key] = Store[dict[str, Any]]( + hass, STORAGE_VERSION, storage_key, private=True + ) + + store: Store[dict[str, Any]] = hass.data[DOMAIN][storage_key] + + # Load cached data + cached_data = await store.async_load() + + # Check if cache is valid (less than 24 hours old) + now = dt_util.utcnow() + if cached_data is not None: + cached_time = dt_util.parse_datetime(cached_data.get("timestamp", "")) + if cached_time and (now - cached_time) < CACHE_DURATION: + # Cache is still valid, return the cached predictions + return cached_data["predictions"] + + # Cache is expired or doesn't exist, fetch new data + predictions = await common_control.async_predict_common_control(hass, user_id) + + # Store the new data with timestamp + cache_data = { + "timestamp": now.isoformat(), + "predictions": predictions, + } + + # Save to cache + await store.async_save(cache_data) + + return predictions diff --git a/homeassistant/components/usage_prediction/common_control.py b/homeassistant/components/usage_prediction/common_control.py new file mode 100644 index 00000000000..b110187cc0a --- /dev/null +++ b/homeassistant/components/usage_prediction/common_control.py @@ -0,0 +1,181 @@ +"""Code to generate common control usage patterns.""" + +from __future__ import annotations + +from collections import Counter, OrderedDict +from collections.abc import Callable +from datetime import datetime, timedelta +import json +import logging + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from homeassistant.components.recorder import get_instance +from homeassistant.components.recorder.db_schema import EventData, Events, EventTypes +from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none +from homeassistant.components.recorder.util import session_scope +from homeassistant.core import HomeAssistant +from homeassistant.util import dt as dt_util + +_LOGGER = logging.getLogger(__name__) + +# Time categories for usage patterns +TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"] + +RESULTS_TO_SHOW = 8 + + +def time_category(dt: datetime) -> str: + """Determine the time category for a given datetime.""" + hour = dt.hour + if 6 <= hour < 12: + return "morning" + if 12 <= hour < 18: + return "afternoon" + if 18 <= hour < 22: + return "evening" + return "night" + + +async def async_predict_common_control( + hass: HomeAssistant, user_id: str +) -> dict[str, list[str]]: + """Generate a list of commonly used entities for a user. + + Args: + hass: Home Assistant instance + user_id: User ID to filter events by. + + Returns: + Dictionary with time categories as keys and lists of most common entity IDs as values + """ + # Get the recorder instance to ensure it's ready + recorder = get_instance(hass) + + # Execute the database operation in the recorder's executor + return await recorder.async_add_executor_job( + _fetch_with_session, hass, _fetch_and_process_data, user_id + ) + + +def _fetch_and_process_data(session: Session, user_id: str) -> dict[str, list[str]]: + """Fetch and process service call events from the database.""" + # Prepare a dictionary to track results + results: OrderedDict[str, Counter[str]] = OrderedDict( + (time_cat, Counter()) for time_cat in TIME_CATEGORIES + ) + + # Keep track of contexts that we processed so that we will only process + # the first service call in a context, and not subsequent calls. + context_processed = set() + + # Build the query to get call_service events + # First, get the event_type_id for 'call_service' + event_type_query = select(EventTypes.event_type_id).where( + EventTypes.event_type == "call_service" + ) + event_type_result = session.execute(event_type_query).first() + + if not event_type_result: + _LOGGER.warning("No call_service events found in database") + return {time_cat: [] for time_cat in TIME_CATEGORIES} + + call_service_type_id = event_type_result[0] + thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp() + user_id_bytes = uuid_hex_to_bytes_or_none(user_id) + if not user_id_bytes: + raise ValueError("Invalid user_id format") + + # Build the main query for events with their data + query = ( + select( + Events.context_id_bin, + Events.context_user_id_bin, + Events.time_fired_ts, + EventData.shared_data, + ) + .select_from(Events) + .outerjoin(EventData, Events.data_id == EventData.data_id) + .where(Events.event_type_id == call_service_type_id) + .where(Events.time_fired_ts >= thirty_days_ago_ts) + .where(Events.context_user_id_bin == user_id_bytes) + .order_by(Events.time_fired_ts) + ) + + # Execute the query + for row in session.execute(query): + context_id = row.context_id_bin + shared_data = row.shared_data + time_fired_ts = row.time_fired_ts + + # Skip if we have already processed an event that was part of this context + if context_id in context_processed: + continue + + # Parse the event data + if not shared_data: + continue + + try: + event_data = json.loads(shared_data) + except (ValueError, TypeError) as err: + _LOGGER.debug("Failed to parse event data: %s", err) + continue + + # Empty event data, skipping + if not event_data: + continue + + service_data = event_data.get("service_data") + + # No service data found, skipping + if not service_data: + continue + + if (target := service_data.get("target")) and ( + target_entity_ids := target.get("entity_id") + ): + entity_ids = target_entity_ids + else: + entity_ids = service_data.get("entity_id") + + # No entity IDs found, skip this event + if entity_ids is None: + continue + + if not isinstance(entity_ids, list): + entity_ids = [entity_ids] + + # Mark this context as processed + context_processed.add(context_id) + + # Convert timestamp to datetime and determine time category + if time_fired_ts: + time_fired = dt_util.utc_from_timestamp(time_fired_ts) + # Convert to local time for time category determination + local_time = dt_util.as_local(time_fired) + period = time_category(local_time) + + # Count entity usage + for entity_id in entity_ids: + results[period][entity_id] += 1 + + # Convert results to lists of top entities + final_results = {} + + for period, period_results in results.items(): + entities = [ + ent_id for (ent_id, _count) in period_results.most_common(RESULTS_TO_SHOW) + ] + final_results[period] = entities + + return final_results + + +def _fetch_with_session( + hass: HomeAssistant, fetch_func: Callable[[Session], dict[str, list[str]]], *args +) -> dict[str, list[str]]: + """Execute a fetch function with a database session.""" + with session_scope(hass=hass, read_only=True) as session: + return fetch_func(session, *args) diff --git a/homeassistant/components/usage_prediction/const.py b/homeassistant/components/usage_prediction/const.py new file mode 100644 index 00000000000..109c276fa5c --- /dev/null +++ b/homeassistant/components/usage_prediction/const.py @@ -0,0 +1,3 @@ +"""Constants for the usage prediction integration.""" + +DOMAIN = "usage_prediction" diff --git a/homeassistant/components/usage_prediction/manifest.json b/homeassistant/components/usage_prediction/manifest.json new file mode 100644 index 00000000000..a1f4d4e7cf2 --- /dev/null +++ b/homeassistant/components/usage_prediction/manifest.json @@ -0,0 +1,10 @@ +{ + "domain": "usage_prediction", + "name": "Usage Prediction", + "codeowners": ["@home-assistant/core"], + "dependencies": ["http", "recorder"], + "documentation": "https://www.home-assistant.io/integrations/usage_prediction", + "integration_type": "system", + "iot_class": "calculated", + "quality_scale": "internal" +} diff --git a/homeassistant/components/usage_prediction/quality_scale.yaml b/homeassistant/components/usage_prediction/quality_scale.yaml new file mode 100644 index 00000000000..e6c3e5da1e3 --- /dev/null +++ b/homeassistant/components/usage_prediction/quality_scale.yaml @@ -0,0 +1,20 @@ +rules: + # Bronze + action-setup: exempt + appropriate-polling: exempt + brands: todo + common-modules: done + config-flow-test-coverage: exempt + config-flow: exempt + dependency-transparency: done + docs-actions: exempt + docs-high-level-description: todo + docs-installation-instructions: exempt + docs-removal-instructions: exempt + entity-event-setup: exempt + entity-unique-id: exempt + has-entity-name: exempt + runtime-data: exempt + test-before-configure: exempt + test-before-setup: exempt + unique-config-entry: exempt diff --git a/homeassistant/components/usage_prediction/strings.json b/homeassistant/components/usage_prediction/strings.json new file mode 100644 index 00000000000..56ab70d2360 --- /dev/null +++ b/homeassistant/components/usage_prediction/strings.json @@ -0,0 +1,3 @@ +{ + "title": "Usage Prediction" +} diff --git a/tests/components/usage_prediction/__init__.py b/tests/components/usage_prediction/__init__.py new file mode 100644 index 00000000000..124766b0c39 --- /dev/null +++ b/tests/components/usage_prediction/__init__.py @@ -0,0 +1 @@ +"""Tests for the usage_prediction integration.""" diff --git a/tests/components/usage_prediction/test_common_control.py b/tests/components/usage_prediction/test_common_control.py new file mode 100644 index 00000000000..b5d8a858344 --- /dev/null +++ b/tests/components/usage_prediction/test_common_control.py @@ -0,0 +1,354 @@ +"""Test the common control usage prediction.""" + +from __future__ import annotations + +from unittest.mock import patch +import uuid + +from freezegun import freeze_time +import pytest + +from homeassistant.components.usage_prediction.common_control import ( + async_predict_common_control, +) +from homeassistant.const import EVENT_CALL_SERVICE +from homeassistant.core import Context, HomeAssistant + +from tests.components.recorder.common import async_wait_recording_done + + +@pytest.mark.usefixtures("recorder_mock") +async def test_empty_database(hass: HomeAssistant) -> None: + """Test function with empty database returns empty results.""" + user_id = str(uuid.uuid4()) + + # Call the function with empty database + results = await async_predict_common_control(hass, user_id) + + # Should return empty lists for all time categories + assert results == { + "morning": [], + "afternoon": [], + "evening": [], + "night": [], + } + + +@pytest.mark.usefixtures("recorder_mock") +async def test_invalid_user_id(hass: HomeAssistant) -> None: + """Test function with invalid user ID returns empty results.""" + # Invalid user ID format (not a valid UUID) + results = await async_predict_common_control(hass, "invalid-user-id") + + # Should return empty lists for all time categories due to invalid user ID + assert results == { + "morning": [], + "afternoon": [], + "evening": [], + "night": [], + } + + +@pytest.mark.usefixtures("recorder_mock") +async def test_with_service_calls(hass: HomeAssistant) -> None: + """Test function with actual service call events in database.""" + user_id = str(uuid.uuid4()) + + # Create service call events at different times of day + # Morning events - use separate service calls to get around context deduplication + with freeze_time("2023-07-01 07:00:00+00:00"): # Morning + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "turn_on", + "service_data": {"entity_id": ["light.living_room", "light.kitchen"]}, + }, + context=Context(user_id=user_id), + ) + await hass.async_block_till_done() + + # Afternoon events + with freeze_time("2023-07-01 14:00:00+00:00"): # Afternoon + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "climate", + "service": "set_temperature", + "service_data": {"entity_id": "climate.thermostat"}, + }, + context=Context(user_id=user_id), + ) + await hass.async_block_till_done() + + # Evening events + with freeze_time("2023-07-01 19:00:00+00:00"): # Evening + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "turn_off", + "service_data": {"entity_id": "light.bedroom"}, + }, + context=Context(user_id=user_id), + ) + await hass.async_block_till_done() + + # Night events + with freeze_time("2023-07-01 23:00:00+00:00"): # Night + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "lock", + "service": "lock", + "service_data": {"entity_id": "lock.front_door"}, + }, + context=Context(user_id=user_id), + ) + await hass.async_block_till_done() + + # Wait for events to be recorded + await async_wait_recording_done(hass) + + # Get predictions - make sure we're still in a reasonable timeframe + with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent + results = await async_predict_common_control(hass, user_id) + + # Verify results contain the expected entities in the correct time periods + assert results == { + "morning": ["climate.thermostat"], + "afternoon": ["light.bedroom", "lock.front_door"], + "evening": [], + "night": ["light.living_room", "light.kitchen"], + } + + +@pytest.mark.usefixtures("recorder_mock") +async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None: + """Test handling of service calls with multiple entity IDs.""" + user_id = str(uuid.uuid4()) + + with freeze_time("2023-07-01 10:00:00+00:00"): # Morning + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "turn_on", + "service_data": { + "entity_id": ["light.living_room", "light.kitchen", "light.hallway"] + }, + }, + context=Context(user_id=user_id), + ) + await hass.async_block_till_done() + + await async_wait_recording_done(hass) + + with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent + results = await async_predict_common_control(hass, user_id) + + # All three lights should be counted (10:00 UTC = 02:00 local = night) + assert results["night"] == ["light.living_room", "light.kitchen", "light.hallway"] + assert results["morning"] == [] + assert results["afternoon"] == [] + assert results["evening"] == [] + + +@pytest.mark.usefixtures("recorder_mock") +async def test_context_deduplication(hass: HomeAssistant) -> None: + """Test that multiple events with the same context are deduplicated.""" + user_id = str(uuid.uuid4()) + context = Context(user_id=user_id) + + with freeze_time("2023-07-01 10:00:00+00:00"): # Morning + # Fire multiple events with the same context + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "turn_on", + "service_data": {"entity_id": "light.living_room"}, + }, + context=context, + ) + await hass.async_block_till_done() + + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "switch", + "service": "turn_on", + "service_data": {"entity_id": "switch.coffee_maker"}, + }, + context=context, # Same context + ) + await hass.async_block_till_done() + + await async_wait_recording_done(hass) + + with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent + results = await async_predict_common_control(hass, user_id) + + # Only the first event should be processed (10:00 UTC = 02:00 local = night) + assert results == { + "morning": [], + "afternoon": [], + "evening": [], + "night": ["light.living_room"], + } + + +@pytest.mark.usefixtures("recorder_mock") +async def test_old_events_excluded(hass: HomeAssistant) -> None: + """Test that events older than 30 days are excluded.""" + user_id = str(uuid.uuid4()) + + # Create an old event (35 days ago) + with freeze_time("2023-05-27 10:00:00+00:00"): # 35 days before July 1st + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "turn_on", + "service_data": {"entity_id": "light.old_event"}, + }, + context=Context(user_id=user_id), + ) + await hass.async_block_till_done() + + # Create a recent event (5 days ago) + with freeze_time("2023-06-26 10:00:00+00:00"): # 5 days before July 1st + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "turn_on", + "service_data": {"entity_id": "light.recent_event"}, + }, + context=Context(user_id=user_id), + ) + await hass.async_block_till_done() + + await async_wait_recording_done(hass) + + # Query with current time + with freeze_time("2023-07-01 10:00:00+00:00"): + results = await async_predict_common_control(hass, user_id) + + # Only recent event should be included (10:00 UTC = 02:00 local = night) + assert results == { + "morning": [], + "afternoon": [], + "evening": [], + "night": ["light.recent_event"], + } + + +@pytest.mark.usefixtures("recorder_mock") +async def test_entities_limit(hass: HomeAssistant) -> None: + """Test that only top entities are returned per time category.""" + user_id = str(uuid.uuid4()) + + # Create more than 5 different entities in morning + with freeze_time("2023-07-01 08:00:00+00:00"): + # Create entities with different frequencies + entities_with_counts = [ + ("light.most_used", 10), + ("light.second", 8), + ("light.third", 6), + ("light.fourth", 4), + ("light.fifth", 2), + ("light.sixth", 1), + ("light.seventh", 1), + ] + + for entity_id, count in entities_with_counts: + for _ in range(count): + # Use different context for each call + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "toggle", + "service_data": {"entity_id": entity_id}, + }, + context=Context(user_id=user_id), + ) + await hass.async_block_till_done() + + await async_wait_recording_done(hass) + + with ( + freeze_time("2023-07-02 10:00:00+00:00"), + patch( + "homeassistant.components.usage_prediction.common_control.RESULTS_TO_SHOW", + 5, + ), + ): # Next day, so events are recent + results = await async_predict_common_control(hass, user_id) + + # Should be the top 5 most used (08:00 UTC = 00:00 local = night) + assert results["night"] == [ + "light.most_used", + "light.second", + "light.third", + "light.fourth", + "light.fifth", + ] + assert results["morning"] == [] + assert results["afternoon"] == [] + assert results["evening"] == [] + + +@pytest.mark.usefixtures("recorder_mock") +async def test_different_users_separated(hass: HomeAssistant) -> None: + """Test that events from different users are properly separated.""" + user_id_1 = str(uuid.uuid4()) + user_id_2 = str(uuid.uuid4()) + + with freeze_time("2023-07-01 10:00:00+00:00"): + # User 1 events + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "turn_on", + "service_data": {"entity_id": "light.user1_light"}, + }, + context=Context(user_id=user_id_1), + ) + await hass.async_block_till_done() + + # User 2 events + hass.bus.async_fire( + EVENT_CALL_SERVICE, + { + "domain": "light", + "service": "turn_on", + "service_data": {"entity_id": "light.user2_light"}, + }, + context=Context(user_id=user_id_2), + ) + await hass.async_block_till_done() + + await async_wait_recording_done(hass) + + # Get results for each user + with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent + results_user1 = await async_predict_common_control(hass, user_id_1) + results_user2 = await async_predict_common_control(hass, user_id_2) + + # Each user should only see their own entities (10:00 UTC = 02:00 local = night) + assert results_user1 == { + "morning": [], + "afternoon": [], + "evening": [], + "night": ["light.user1_light"], + } + + assert results_user2 == { + "morning": [], + "afternoon": [], + "evening": [], + "night": ["light.user2_light"], + } diff --git a/tests/components/usage_prediction/test_websocket.py b/tests/components/usage_prediction/test_websocket.py new file mode 100644 index 00000000000..6e6f8aed79f --- /dev/null +++ b/tests/components/usage_prediction/test_websocket.py @@ -0,0 +1,105 @@ +"""Test usage_prediction WebSocket API.""" + +from collections.abc import Generator +from datetime import timedelta +from unittest.mock import Mock, patch + +from freezegun import freeze_time +import pytest + +from homeassistant.core import HomeAssistant +from homeassistant.setup import async_setup_component +from homeassistant.util import dt as dt_util + +from tests.common import MockUser +from tests.typing import WebSocketGenerator + + +@pytest.fixture +def mock_predict_common_control() -> Generator[Mock]: + """Return a mock result for common control.""" + with patch( + "homeassistant.components.usage_prediction.common_control.async_predict_common_control", + return_value={ + "morning": ["light.kitchen"], + "afternoon": ["climate.thermostat"], + "evening": ["light.bedroom"], + "night": ["lock.front_door"], + }, + ) as mock_predict: + yield mock_predict + + +@pytest.mark.usefixtures("recorder_mock") +async def test_common_control( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + hass_admin_user: MockUser, + mock_predict_common_control: Mock, +) -> None: + """Test usage_prediction common control WebSocket command.""" + assert await async_setup_component(hass, "usage_prediction", {}) + + client = await hass_ws_client(hass) + await client.send_json({"id": 1, "type": "usage_prediction/common_control"}) + + msg = await client.receive_json() + assert msg["id"] == 1 + assert msg["type"] == "result" + assert msg["success"] is True + assert msg["result"] == { + "entities": [ + "light.kitchen", + ] + } + assert mock_predict_common_control.call_count == 1 + assert mock_predict_common_control.mock_calls[0][1][1] == hass_admin_user.id + + +@pytest.mark.usefixtures("recorder_mock") +async def test_caching_behavior( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + mock_predict_common_control: Mock, +) -> None: + """Test that results are cached for 24 hours.""" + assert await async_setup_component(hass, "usage_prediction", {}) + + client = await hass_ws_client(hass) + now = dt_util.utcnow() + + # First call should fetch from database + with freeze_time(now): + await client.send_json({"id": 1, "type": "usage_prediction/common_control"}) + msg = await client.receive_json() + assert msg["success"] is True + assert msg["result"] == { + "entities": [ + "light.kitchen", + ] + } + assert mock_predict_common_control.call_count == 1 + + mock_predict_common_control.return_value["morning"].append("light.bla") + + # Second call within 24 hours should use cache + with freeze_time(now + timedelta(hours=23)): + await client.send_json({"id": 2, "type": "usage_prediction/common_control"}) + msg = await client.receive_json() + assert msg["success"] is True + assert msg["result"] == { + "entities": [ + "light.kitchen", + ] + } + # Should still be 1 (no new database call) + assert mock_predict_common_control.call_count == 1 + + # Third call after 24 hours should fetch from database again + with freeze_time(now + timedelta(hours=25)): + await client.send_json({"id": 3, "type": "usage_prediction/common_control"}) + msg = await client.receive_json() + assert msg["success"] is True + assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]} + # Should now be 2 (new database call) + assert mock_predict_common_control.call_count == 2