Make caching more efficient

This commit is contained in:
Paulus Schoutsen
2025-08-26 20:55:21 +02:00
parent 9f751a6380
commit 1fed4b83ef
2 changed files with 37 additions and 33 deletions

View File

@@ -67,21 +67,17 @@ async def get_cached_common_control(
# Create a unique storage key for this user # Create a unique storage key for this user
storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}" storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}"
# Get or create store for this user
if storage_key not in hass.data[DOMAIN]: if storage_key not in hass.data[DOMAIN]:
hass.data[DOMAIN][storage_key] = Store[dict[str, Any]]( store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key)
hass, STORAGE_VERSION, storage_key, private=True cached_data: dict[str, Any] | None = await store.async_load()
) hass.data[DOMAIN][storage_key] = cached_data
store: Store[dict[str, Any]] = hass.data[DOMAIN][storage_key] cached_data = hass.data[DOMAIN].get(storage_key)
# Load cached data
cached_data = await store.async_load()
# Check if cache is valid (less than 24 hours old) # Check if cache is valid (less than 24 hours old)
now = dt_util.utcnow() now = dt_util.utcnow()
if cached_data is not None: if cached_data is not None:
cached_time = dt_util.parse_datetime(cached_data.get("timestamp", "")) cached_time = dt_util.parse_datetime(cached_data["timestamp"])
if cached_time and (now - cached_time) < CACHE_DURATION: if cached_time and (now - cached_time) < CACHE_DURATION:
# Cache is still valid, return the cached predictions # Cache is still valid, return the cached predictions
return cached_data["predictions"] return cached_data["predictions"]
@@ -90,12 +86,14 @@ async def get_cached_common_control(
predictions = await common_control.async_predict_common_control(hass, user_id) predictions = await common_control.async_predict_common_control(hass, user_id)
# Store the new data with timestamp # Store the new data with timestamp
cache_data = { cached_data = {
"timestamp": now.isoformat(), "timestamp": now.isoformat(),
"predictions": predictions, "predictions": predictions,
} }
# Save to cache # Save to cache
await store.async_save(cache_data) store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key)
store.async_delay_save(lambda: cached_data)
hass.data[DOMAIN][storage_key] = cached_data
return predictions return predictions

View File

@@ -1,7 +1,8 @@
"""Test usage_prediction WebSocket API.""" """Test usage_prediction WebSocket API."""
from collections.abc import Generator from collections.abc import Generator
from datetime import timedelta from copy import deepcopy
from datetime import datetime, timedelta
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from freezegun import freeze_time from freezegun import freeze_time
@@ -66,12 +67,13 @@ async def test_caching_behavior(
assert await async_setup_component(hass, "usage_prediction", {}) assert await async_setup_component(hass, "usage_prediction", {})
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
now = dt_util.utcnow() now = datetime(2026, 8, 26, 9, 0, 0, tzinfo=dt_util.DEFAULT_TIME_ZONE)
# First call should fetch from database # First call should fetch from database
with freeze_time(now): with freeze_time(now):
await client.send_json({"id": 1, "type": "usage_prediction/common_control"}) await client.send_json({"id": 1, "type": "usage_prediction/common_control"})
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] is True assert msg["success"] is True
assert msg["result"] == { assert msg["result"] == {
"entities": [ "entities": [
@@ -80,12 +82,15 @@ async def test_caching_behavior(
} }
assert mock_predict_common_control.call_count == 1 assert mock_predict_common_control.call_count == 1
mock_predict_common_control.return_value["morning"].append("light.bla") new_result = deepcopy(mock_predict_common_control.return_value)
new_result["morning"].append("light.bla")
mock_predict_common_control.return_value = new_result
# Second call within 24 hours should use cache # Second call within 24 hours should use cache
with freeze_time(now + timedelta(hours=23)): with freeze_time(now + timedelta(hours=23)):
await client.send_json({"id": 2, "type": "usage_prediction/common_control"}) await client.send_json({"id": 2, "type": "usage_prediction/common_control"})
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] is True assert msg["success"] is True
assert msg["result"] == { assert msg["result"] == {
"entities": [ "entities": [
@@ -99,6 +104,7 @@ async def test_caching_behavior(
with freeze_time(now + timedelta(hours=25)): with freeze_time(now + timedelta(hours=25)):
await client.send_json({"id": 3, "type": "usage_prediction/common_control"}) await client.send_json({"id": 3, "type": "usage_prediction/common_control"})
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] is True assert msg["success"] is True
assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]} assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]}
# Should now be 2 (new database call) # Should now be 2 (new database call)