diff --git a/homeassistant/components/usage_prediction/__init__.py b/homeassistant/components/usage_prediction/__init__.py index c95e73e4351..7d9c1083c02 100644 --- a/homeassistant/components/usage_prediction/__init__.py +++ b/homeassistant/components/usage_prediction/__init__.py @@ -67,21 +67,17 @@ async def get_cached_common_control( # Create a unique storage key for this user storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}" - # Get or create store for this user if storage_key not in hass.data[DOMAIN]: - hass.data[DOMAIN][storage_key] = Store[dict[str, Any]]( - hass, STORAGE_VERSION, storage_key, private=True - ) + store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key) + cached_data: dict[str, Any] | None = await store.async_load() + hass.data[DOMAIN][storage_key] = cached_data - store: Store[dict[str, Any]] = hass.data[DOMAIN][storage_key] - - # Load cached data - cached_data = await store.async_load() + cached_data = hass.data[DOMAIN].get(storage_key) # Check if cache is valid (less than 24 hours old) now = dt_util.utcnow() if cached_data is not None: - cached_time = dt_util.parse_datetime(cached_data.get("timestamp", "")) + cached_time = dt_util.parse_datetime(cached_data["timestamp"]) if cached_time and (now - cached_time) < CACHE_DURATION: # Cache is still valid, return the cached predictions return cached_data["predictions"] @@ -90,12 +86,14 @@ async def get_cached_common_control( predictions = await common_control.async_predict_common_control(hass, user_id) # Store the new data with timestamp - cache_data = { + cached_data = { "timestamp": now.isoformat(), "predictions": predictions, } # Save to cache - await store.async_save(cache_data) + store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key) + store.async_delay_save(lambda: cached_data) + hass.data[DOMAIN][storage_key] = cached_data return predictions diff --git a/tests/components/usage_prediction/test_websocket.py b/tests/components/usage_prediction/test_websocket.py index 6e6f8aed79f..9d48ea94bc6 100644 --- a/tests/components/usage_prediction/test_websocket.py +++ b/tests/components/usage_prediction/test_websocket.py @@ -1,7 +1,8 @@ """Test usage_prediction WebSocket API.""" from collections.abc import Generator -from datetime import timedelta +from copy import deepcopy +from datetime import datetime, timedelta from unittest.mock import Mock, patch from freezegun import freeze_time @@ -66,40 +67,45 @@ async def test_caching_behavior( assert await async_setup_component(hass, "usage_prediction", {}) client = await hass_ws_client(hass) - now = dt_util.utcnow() + now = datetime(2026, 8, 26, 9, 0, 0, tzinfo=dt_util.DEFAULT_TIME_ZONE) # First call should fetch from database with freeze_time(now): await client.send_json({"id": 1, "type": "usage_prediction/common_control"}) msg = await client.receive_json() - assert msg["success"] is True - assert msg["result"] == { - "entities": [ - "light.kitchen", - ] - } - assert mock_predict_common_control.call_count == 1 - mock_predict_common_control.return_value["morning"].append("light.bla") + assert msg["success"] is True + assert msg["result"] == { + "entities": [ + "light.kitchen", + ] + } + assert mock_predict_common_control.call_count == 1 + + new_result = deepcopy(mock_predict_common_control.return_value) + new_result["morning"].append("light.bla") + mock_predict_common_control.return_value = new_result # Second call within 24 hours should use cache with freeze_time(now + timedelta(hours=23)): await client.send_json({"id": 2, "type": "usage_prediction/common_control"}) msg = await client.receive_json() - assert msg["success"] is True - assert msg["result"] == { - "entities": [ - "light.kitchen", - ] - } - # Should still be 1 (no new database call) - assert mock_predict_common_control.call_count == 1 + + assert msg["success"] is True + assert msg["result"] == { + "entities": [ + "light.kitchen", + ] + } + # Should still be 1 (no new database call) + assert mock_predict_common_control.call_count == 1 # Third call after 24 hours should fetch from database again with freeze_time(now + timedelta(hours=25)): await client.send_json({"id": 3, "type": "usage_prediction/common_control"}) msg = await client.receive_json() - assert msg["success"] is True - assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]} - # Should now be 2 (new database call) - assert mock_predict_common_control.call_count == 2 + + assert msg["success"] is True + assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]} + # Should now be 2 (new database call) + assert mock_predict_common_control.call_count == 2