Make caching more efficient

This commit is contained in:
Paulus Schoutsen
2025-08-26 20:55:21 +02:00
parent 9f751a6380
commit 1fed4b83ef
2 changed files with 37 additions and 33 deletions

View File

@@ -67,21 +67,17 @@ async def get_cached_common_control(
# Create a unique storage key for this user # Create a unique storage key for this user
storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}" storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}"
# Get or create store for this user
if storage_key not in hass.data[DOMAIN]: if storage_key not in hass.data[DOMAIN]:
hass.data[DOMAIN][storage_key] = Store[dict[str, Any]]( store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key)
hass, STORAGE_VERSION, storage_key, private=True cached_data: dict[str, Any] | None = await store.async_load()
) hass.data[DOMAIN][storage_key] = cached_data
store: Store[dict[str, Any]] = hass.data[DOMAIN][storage_key] cached_data = hass.data[DOMAIN].get(storage_key)
# Load cached data
cached_data = await store.async_load()
# Check if cache is valid (less than 24 hours old) # Check if cache is valid (less than 24 hours old)
now = dt_util.utcnow() now = dt_util.utcnow()
if cached_data is not None: if cached_data is not None:
cached_time = dt_util.parse_datetime(cached_data.get("timestamp", "")) cached_time = dt_util.parse_datetime(cached_data["timestamp"])
if cached_time and (now - cached_time) < CACHE_DURATION: if cached_time and (now - cached_time) < CACHE_DURATION:
# Cache is still valid, return the cached predictions # Cache is still valid, return the cached predictions
return cached_data["predictions"] return cached_data["predictions"]
@@ -90,12 +86,14 @@ async def get_cached_common_control(
predictions = await common_control.async_predict_common_control(hass, user_id) predictions = await common_control.async_predict_common_control(hass, user_id)
# Store the new data with timestamp # Store the new data with timestamp
cache_data = { cached_data = {
"timestamp": now.isoformat(), "timestamp": now.isoformat(),
"predictions": predictions, "predictions": predictions,
} }
# Save to cache # Save to cache
await store.async_save(cache_data) store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key)
store.async_delay_save(lambda: cached_data)
hass.data[DOMAIN][storage_key] = cached_data
return predictions return predictions

View File

@@ -1,7 +1,8 @@
"""Test usage_prediction WebSocket API.""" """Test usage_prediction WebSocket API."""
from collections.abc import Generator from collections.abc import Generator
from datetime import timedelta from copy import deepcopy
from datetime import datetime, timedelta
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from freezegun import freeze_time from freezegun import freeze_time
@@ -66,40 +67,45 @@ async def test_caching_behavior(
assert await async_setup_component(hass, "usage_prediction", {}) assert await async_setup_component(hass, "usage_prediction", {})
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
now = dt_util.utcnow() now = datetime(2026, 8, 26, 9, 0, 0, tzinfo=dt_util.DEFAULT_TIME_ZONE)
# First call should fetch from database # First call should fetch from database
with freeze_time(now): with freeze_time(now):
await client.send_json({"id": 1, "type": "usage_prediction/common_control"}) await client.send_json({"id": 1, "type": "usage_prediction/common_control"})
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
assert mock_predict_common_control.call_count == 1
mock_predict_common_control.return_value["morning"].append("light.bla") assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
assert mock_predict_common_control.call_count == 1
new_result = deepcopy(mock_predict_common_control.return_value)
new_result["morning"].append("light.bla")
mock_predict_common_control.return_value = new_result
# Second call within 24 hours should use cache # Second call within 24 hours should use cache
with freeze_time(now + timedelta(hours=23)): with freeze_time(now + timedelta(hours=23)):
await client.send_json({"id": 2, "type": "usage_prediction/common_control"}) await client.send_json({"id": 2, "type": "usage_prediction/common_control"})
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == { assert msg["success"] is True
"entities": [ assert msg["result"] == {
"light.kitchen", "entities": [
] "light.kitchen",
} ]
# Should still be 1 (no new database call) }
assert mock_predict_common_control.call_count == 1 # Should still be 1 (no new database call)
assert mock_predict_common_control.call_count == 1
# Third call after 24 hours should fetch from database again # Third call after 24 hours should fetch from database again
with freeze_time(now + timedelta(hours=25)): with freeze_time(now + timedelta(hours=25)):
await client.send_json({"id": 3, "type": "usage_prediction/common_control"}) await client.send_json({"id": 3, "type": "usage_prediction/common_control"})
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]} assert msg["success"] is True
# Should now be 2 (new database call) assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]}
assert mock_predict_common_control.call_count == 2 # Should now be 2 (new database call)
assert mock_predict_common_control.call_count == 2