Compare commits


1 Commit

Author | SHA1 | Message | Date
farmio | a385700cc4 | Add optional compression to Store | 2026-04-23 15:16:05 +02:00
2 changed files with 274 additions and 60 deletions
+146 -60
@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Iterable, Mapping, Sequence
from compression import zstd
from contextlib import suppress
from copy import deepcopy
import inspect
@@ -48,6 +49,34 @@ STORAGE_MANAGER: HassKey[_StoreManager] = HassKey("storage_manager")
MANAGER_CLEANUP_DELAY = 60
def _load_json_file(path: str | Path) -> json_util.JsonValueType:
"""Load JSON from a file, transparently decompressing .zst files.
Returns ``{}`` (the same sentinel as :func:`json_util.load_json`) when
the file does not exist. Raises :class:`HomeAssistantError` wrapping the
original exception when the file is corrupt or cannot be read.
"""
if not str(path).endswith(".zst"):
return json_util.load_json(path)
try:
with open(path, "rb") as fh:
raw = zstd.decompress(fh.read())
except FileNotFoundError:
_LOGGER.debug("JSON file not found: %s", path)
return {}
except zstd.ZstdError as err:
_LOGGER.exception("Could not decompress storage file: %s", path)
raise HomeAssistantError(f"Error decompressing {path}: {err}") from err
except OSError as err:
_LOGGER.exception("Storage file reading failed: %s", path)
raise HomeAssistantError(f"Error while loading {path}: {err}") from err
try:
return json_util.json_loads(raw)
except json_util.JSON_DECODE_EXCEPTIONS as err:
_LOGGER.exception("Could not parse JSON content: %s", path)
raise HomeAssistantError(f"Error while loading {path}: {err}") from err
async def async_migrator[_T: Mapping[str, Any] | Sequence[Any]](
hass: HomeAssistant,
old_path: str,
@@ -214,7 +243,7 @@ class _StoreManager:
storage_file: Path = storage_path.joinpath(key)
try:
if storage_file.is_file():
data_preload[key] = json_util.load_json(storage_file)
data_preload[key] = _load_json_file(storage_file)
except Exception as ex: # noqa: BLE001
_LOGGER.debug("Error loading %s: %s", key, ex)
@@ -239,6 +268,7 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
max_readable_version: int | None = None,
minor_version: int = 1,
read_only: bool = False,
compress: bool = False,
serialize_in_event_loop: bool = True,
) -> None:
"""Initialize storage class.
@@ -282,10 +312,36 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
self._next_write_time = 0.0
self._manager = get_internal_store_manager(hass)
self._serialize_in_event_loop = serialize_in_event_loop
self._compress = compress
@cached_property
def path(self):
"""Return the config path."""
if self._compress:
return self.hass.config.path(STORAGE_DIR, self.key + ".zst")
return self.hass.config.path(STORAGE_DIR, self.key)
@cached_property
def _cache_key(self):
"""Return the cache key used with _StoreManager.
For compressed stores the file on disk is named ``key.zst``, which is
what ``_StoreManager._files`` tracks. Fetching with the bare ``key``
would always return a cached "file does not exist" result, so the
suffix is appended here to match the real filename.
"""
if self._compress:
return self.key + ".zst"
return self.key
@cached_property
def _uncompressed_path(self):
"""Return the plain (uncompressed) config path.
Used as a fallback when ``compress=True`` but only an uncompressed file
exists (e.g. the user manually extracted or edited the file), and to
clean up the old file after a successful compressed write.
"""
return self.hass.config.path(STORAGE_DIR, self.key)
def make_read_only(self) -> None:
@@ -359,66 +415,19 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
# We make a copy because code might assume it's safe to mutate loaded data
# and we don't want that to mess with what we're trying to store.
data = deepcopy(data)
elif cache := self._manager.async_fetch(self.key):
elif cache := self._manager.async_fetch(self._cache_key):
exists, data = cache
if not exists:
return None
else:
try:
data = await self.hass.async_add_executor_job(
json_util.load_json, self.path
)
data = await self.hass.async_add_executor_job(self._load_data_from_disk)
except HomeAssistantError as err:
if isinstance(err.__cause__, JSONDecodeError):
# If we have a JSONDecodeError, it means the file is corrupt.
# We can't recover from this, so we'll log an error, rename the file and
# return None so that we can start with a clean slate which will
# allow startup to continue so they can restore from a backup.
isotime = dt_util.utcnow().isoformat()
corrupt_postfix = f".corrupt.{isotime}"
corrupt_path = f"{self.path}{corrupt_postfix}"
await self.hass.async_add_executor_job(
os.rename, self.path, corrupt_path
)
storage_key = self.key
_LOGGER.error(
"Unrecoverable error decoding storage %s at %s; "
"This may indicate an unclean shutdown, invalid syntax "
"from manual edits, or disk corruption; "
"The corrupt file has been saved as %s; "
"It is recommended to restore from backup: %s",
storage_key,
self.path,
corrupt_path,
err,
)
from .issue_registry import ( # noqa: PLC0415
IssueSeverity,
async_create_issue,
)
issue_domain = HOMEASSISTANT_DOMAIN
if (
domain := (storage_key.partition(".")[0])
) and domain in self.hass.config.components:
issue_domain = domain
async_create_issue(
self.hass,
HOMEASSISTANT_DOMAIN,
f"storage_corruption_{storage_key}_{isotime}",
is_fixable=True,
issue_domain=issue_domain,
translation_key="storage_corruption",
is_persistent=True,
severity=IssueSeverity.CRITICAL,
translation_placeholders={
"storage_key": storage_key,
"original_path": self.path,
"corrupt_path": corrupt_path,
"error": str(err),
},
)
if isinstance(err.__cause__, (JSONDecodeError, zstd.ZstdError)):
# If the file is corrupt, we log an error, rename it, and
# return None so startup can continue from a clean slate.
# The user can then restore from a backup.
await self._async_handle_corrupt_file(err)
return None
raise
@@ -570,7 +579,7 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
async def _async_handle_write_data(self, *_args):
"""Handle writing the config."""
async with self._write_lock:
self._manager.async_invalidate(self.key)
self._manager.async_invalidate(self._cache_key)
self._async_cleanup_delay_listener()
self._async_cleanup_final_write_listener()
@@ -607,6 +616,22 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
mode, json_data = json_helper.prepare_save_json(data, encoder=self._encoder)
self._write_prepared_data(mode, json_data)
def _load_data_from_disk(self) -> json_util.JsonValueType:
"""Load data from disk.
Called in the executor. For compressed stores the compressed path is
tried first; if it does not exist the plain file is used as a fallback
so that a user-edited uncompressed file is transparently picked up.
Returns ``{}`` (the same sentinel as :func:`json_util.load_json`) when
neither file exists.
"""
data = _load_json_file(self.path)
if data == {} and self._compress:
# .zst not found; fall back to the plain file.
data = _load_json_file(self._uncompressed_path)
return data
def _write_prepared_data(self, mode: str, json_data: str | bytes) -> None:
"""Write the data."""
path = self.path
@@ -616,7 +641,62 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
write_method = (
write_utf8_file_atomic if self._atomic_writes else write_utf8_file
)
write_method(path, json_data, self._private, mode=mode)
if self._compress:
# Ensure we have bytes before compressing.
if isinstance(json_data, str):
json_data = json_data.encode("utf-8")
compressed = zstd.compress(json_data)
write_method(path, compressed, self._private, mode="wb")
# Remove the old uncompressed file (migration from plain → compressed).
uncompressed = self._uncompressed_path
if os.path.isfile(uncompressed):
os.unlink(uncompressed)
else:
write_method(path, json_data, self._private, mode=mode)
async def _async_handle_corrupt_file(self, err: Exception) -> None:
"""Rename a corrupt storage file and create a repair issue."""
from .issue_registry import IssueSeverity, async_create_issue # noqa: PLC0415
isotime = dt_util.utcnow().isoformat()
corrupt_postfix = f".corrupt.{isotime}"
corrupt_path = f"{self.path}{corrupt_postfix}"
await self.hass.async_add_executor_job(os.rename, self.path, corrupt_path)
storage_key = self.key
_LOGGER.error(
"Unrecoverable error decoding storage %s at %s; "
"This may indicate an unclean shutdown, invalid syntax "
"from manual edits, or disk corruption; "
"The corrupt file has been saved as %s; "
"It is recommended to restore from backup: %s",
storage_key,
self.path,
corrupt_path,
err,
)
issue_domain = HOMEASSISTANT_DOMAIN
if (
domain := (storage_key.partition(".")[0])
) and domain in self.hass.config.components:
issue_domain = domain
async_create_issue(
self.hass,
HOMEASSISTANT_DOMAIN,
f"storage_corruption_{storage_key}_{isotime}",
is_fixable=True,
issue_domain=issue_domain,
translation_key="storage_corruption",
is_persistent=True,
severity=IssueSeverity.CRITICAL,
translation_placeholders={
"storage_key": storage_key,
"original_path": self.path,
"corrupt_path": corrupt_path,
"error": str(err),
},
)
async def _async_migrate_func(self, old_major_version, old_minor_version, old_data):
"""Migrate to the new version."""
@@ -624,9 +704,15 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
async def async_remove(self) -> None:
"""Remove all data."""
self._manager.async_invalidate(self.key)
self._manager.async_invalidate(self._cache_key)
self._async_cleanup_delay_listener()
self._async_cleanup_final_write_listener()
with suppress(FileNotFoundError):
await self.hass.async_add_executor_job(os.unlink, self.path)
def _remove_files() -> None:
with suppress(FileNotFoundError):
os.unlink(self.path)
if self._compress:
with suppress(FileNotFoundError):
os.unlink(self._uncompressed_path)
await self.hass.async_add_executor_job(_remove_files)
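
For reference, the on-disk format this write path produces is nothing more than zstd-compressed UTF-8 JSON; there is no extra framing beyond what ``compression.zstd`` emits. A minimal stand-alone sketch of the round trip (Python 3.14+ stdlib only; the key name and payload below are illustrative, not taken from the change):

import json
from compression import zstd  # stdlib module, new in Python 3.14
from pathlib import Path

# Illustrative storage envelope; "my_domain.config" is a made-up key.
payload = {
    "version": 1,
    "minor_version": 1,
    "key": "my_domain.config",
    "data": {"answer": 42},
}

path = Path("my_domain.config.zst")
# Write: encode to UTF-8 bytes, then compress (as _write_prepared_data does).
path.write_bytes(zstd.compress(json.dumps(payload).encode("utf-8")))
# Read: decompress, then parse (as _load_json_file does).
restored = json.loads(zstd.decompress(path.read_bytes()))
assert restored == payload
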
+128
@@ -1,6 +1,7 @@
"""Tests for the storage helper."""
import asyncio
from compression import zstd
from datetime import timedelta
import json
import os
@@ -1382,3 +1383,130 @@ async def test_load_empty_returns_none_and_read_only(
await store.async_save({"new": "data"})
assert hass_storage[MOCK_KEY]["data"] == MOCK_DATA
assert hass_storage[MOCK_KEY]["version"] == 99
async def test_compress_save_load_round_trip(tmpdir: py.path.local) -> None:
"""Test that a compressed store saves a .zst file and loads back correctly."""
loop = asyncio.get_running_loop()
config_dir = await loop.run_in_executor(None, tmpdir.mkdir, "temp_storage")
async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
await store.async_save(MOCK_DATA)
storage_path = Path(config_dir.strpath) / ".storage"
zst_file = storage_path / (MOCK_KEY + ".zst")
plain_file = storage_path / MOCK_KEY
assert zst_file.is_file()
assert not plain_file.exists()
raw = zstd.decompress(zst_file.read_bytes())
on_disk = json.loads(raw)
assert on_disk["data"] == MOCK_DATA
loaded = await store.async_load()
assert loaded == MOCK_DATA
await hass.async_stop(force=True)
async def test_compress_migrates_plain_to_compressed(tmpdir: py.path.local) -> None:
"""Test that saving with compress=True removes an existing plain file."""
loop = asyncio.get_running_loop()
config_dir = await loop.run_in_executor(None, tmpdir.mkdir, "temp_storage")
async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
plain_store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
await plain_store.async_save(MOCK_DATA)
storage_path = Path(config_dir.strpath) / ".storage"
plain_file = storage_path / MOCK_KEY
assert plain_file.is_file()
compressed_store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
# Before the first compressed write, the plain file is still the fallback.
loaded = await compressed_store.async_load()
assert loaded == MOCK_DATA
# Saving with compress=True should write .zst and remove the plain file.
await compressed_store.async_save(MOCK_DATA2)
zst_file = storage_path / (MOCK_KEY + ".zst")
assert zst_file.is_file()
assert not plain_file.exists()
loaded = await compressed_store.async_load()
assert loaded == MOCK_DATA2
await hass.async_stop(force=True)
async def test_compress_corrupt_file(
tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
) -> None:
"""Test that a corrupt .zst file is handled gracefully."""
loop = asyncio.get_running_loop()
config_dir = await loop.run_in_executor(None, tmpdir.mkdir, "temp_storage")
async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
await store.async_save(MOCK_DATA)
storage_path = Path(config_dir.strpath) / ".storage"
zst_file = storage_path / (MOCK_KEY + ".zst")
def _corrupt_file() -> None:
zst_file.write_bytes(b"this is not valid zstd data")
await hass.async_add_executor_job(_corrupt_file)
loaded = await store.async_load()
assert loaded is None
assert "Unrecoverable error decoding storage" in caplog.text
files = await hass.async_add_executor_job(os.listdir, storage_path)
corrupt_files = [f for f in files if ".corrupt" in f]
assert len(corrupt_files) == 1
await hass.async_stop(force=True)
async def test_compress_store_manager_cache(tmpdir: py.path.local) -> None:
"""Test that compressed stores are cached and served by the store manager."""
loop = asyncio.get_running_loop()
def _setup_mock_storage() -> py.path.local:
config_dir = tmpdir.mkdir("temp_config")
tmp_storage = config_dir.mkdir(".storage")
payload = json.dumps(
{
"version": MOCK_VERSION,
"minor_version": 1,
"key": MOCK_KEY,
"data": MOCK_DATA,
}
).encode()
tmp_storage.join(MOCK_KEY + ".zst").write_binary(zstd.compress(payload))
return config_dir
config_dir = await loop.run_in_executor(None, _setup_mock_storage)
async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
store_manager = storage.get_internal_store_manager(hass)
await store_manager.async_initialize()
await store_manager.async_preload([MOCK_KEY + ".zst"])
# The cache key for a compressed store is key + ".zst".
result = store_manager.async_fetch(MOCK_KEY + ".zst")
assert result is not None
exists, cached_data = result
assert exists is True
assert cached_data["data"] == MOCK_DATA # type: ignore[index]
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
loaded = await store.async_load()
assert loaded == MOCK_DATA
await hass.async_stop(force=True)
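
The plain-to-compressed migration these tests exercise boils down to a read-with-fallback followed by a compress-and-unlink on save. A stand-alone sketch of that sequence (illustrative file names and stdlib calls only, not Home Assistant APIs):

import json
from compression import zstd  # stdlib module, new in Python 3.14
from pathlib import Path

plain = Path("my_domain.config")           # hypothetical plain storage file
compressed = Path("my_domain.config.zst")  # its compressed sibling
plain.write_text(json.dumps({"data": {"a": 1}}), encoding="utf-8")

# Read: prefer the .zst file, fall back to the plain one
# (mirrors _load_data_from_disk above).
if compressed.is_file():
    data = json.loads(zstd.decompress(compressed.read_bytes()))
else:
    data = json.loads(plain.read_text(encoding="utf-8"))

# Save: write the compressed file, then remove the plain one
# (mirrors the compress branch of _write_prepared_data).
compressed.write_bytes(zstd.compress(json.dumps(data).encode("utf-8")))
plain.unlink(missing_ok=True)
assert compressed.is_file() and not plain.exists()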