Make caching more efficient

This commit is contained in:
Paulus Schoutsen
2025-08-26 20:55:21 +02:00
parent 9f751a6380
commit 1fed4b83ef
2 changed files with 37 additions and 33 deletions

View File

@@ -67,21 +67,17 @@ async def get_cached_common_control(
# Create a unique storage key for this user
storage_key = f"{STORAGE_KEY_PREFIX}.{user_id}"
# Get or create store for this user
if storage_key not in hass.data[DOMAIN]:
hass.data[DOMAIN][storage_key] = Store[dict[str, Any]](
hass, STORAGE_VERSION, storage_key, private=True
)
store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key)
cached_data: dict[str, Any] | None = await store.async_load()
hass.data[DOMAIN][storage_key] = cached_data
store: Store[dict[str, Any]] = hass.data[DOMAIN][storage_key]
# Load cached data
cached_data = await store.async_load()
cached_data = hass.data[DOMAIN].get(storage_key)
# Check if cache is valid (less than 24 hours old)
now = dt_util.utcnow()
if cached_data is not None:
cached_time = dt_util.parse_datetime(cached_data.get("timestamp", ""))
cached_time = dt_util.parse_datetime(cached_data["timestamp"])
if cached_time and (now - cached_time) < CACHE_DURATION:
# Cache is still valid, return the cached predictions
return cached_data["predictions"]
@@ -90,12 +86,14 @@ async def get_cached_common_control(
predictions = await common_control.async_predict_common_control(hass, user_id)
# Store the new data with timestamp
cache_data = {
cached_data = {
"timestamp": now.isoformat(),
"predictions": predictions,
}
# Save to cache
await store.async_save(cache_data)
store = Store[dict[str, Any]](hass, STORAGE_VERSION, storage_key)
store.async_delay_save(lambda: cached_data)
hass.data[DOMAIN][storage_key] = cached_data
return predictions

View File

@@ -1,7 +1,8 @@
"""Test usage_prediction WebSocket API."""
from collections.abc import Generator
from datetime import timedelta
from copy import deepcopy
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from freezegun import freeze_time
@@ -66,40 +67,45 @@ async def test_caching_behavior(
assert await async_setup_component(hass, "usage_prediction", {})
client = await hass_ws_client(hass)
now = dt_util.utcnow()
now = datetime(2026, 8, 26, 9, 0, 0, tzinfo=dt_util.DEFAULT_TIME_ZONE)
# First call should fetch from database
with freeze_time(now):
await client.send_json({"id": 1, "type": "usage_prediction/common_control"})
msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
assert mock_predict_common_control.call_count == 1
mock_predict_common_control.return_value["morning"].append("light.bla")
assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
assert mock_predict_common_control.call_count == 1
new_result = deepcopy(mock_predict_common_control.return_value)
new_result["morning"].append("light.bla")
mock_predict_common_control.return_value = new_result
# Second call within 24 hours should use cache
with freeze_time(now + timedelta(hours=23)):
await client.send_json({"id": 2, "type": "usage_prediction/common_control"})
msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
# Should still be 1 (no new database call)
assert mock_predict_common_control.call_count == 1
assert msg["success"] is True
assert msg["result"] == {
"entities": [
"light.kitchen",
]
}
# Should still be 1 (no new database call)
assert mock_predict_common_control.call_count == 1
# Third call after 24 hours should fetch from database again
with freeze_time(now + timedelta(hours=25)):
await client.send_json({"id": 3, "type": "usage_prediction/common_control"})
msg = await client.receive_json()
assert msg["success"] is True
assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]}
# Should now be 2 (new database call)
assert mock_predict_common_control.call_count == 2
assert msg["success"] is True
assert msg["result"] == {"entities": ["light.kitchen", "light.bla"]}
# Should now be 2 (new database call)
assert mock_predict_common_control.call_count == 2