Compare commits

...

1 Commits

Author SHA1 Message Date
farmio a385700cc4 Add optional compression to Store 2026-04-23 15:16:05 +02:00
2 changed files with 274 additions and 60 deletions
+146 -60
View File
@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from compression import zstd
from contextlib import suppress from contextlib import suppress
from copy import deepcopy from copy import deepcopy
import inspect import inspect
@@ -48,6 +49,34 @@ STORAGE_MANAGER: HassKey[_StoreManager] = HassKey("storage_manager")
MANAGER_CLEANUP_DELAY = 60 MANAGER_CLEANUP_DELAY = 60
def _load_json_file(path: str | Path) -> json_util.JsonValueType:
    """Load JSON from ``path``, decompressing ``.zst`` files first.

    Paths that do not end in ``.zst`` are handed straight to
    :func:`json_util.load_json`. A ``.zst`` path is read in binary mode,
    decompressed with zstd and parsed with :func:`json_util.json_loads`.

    Returns ``{}`` when a ``.zst`` file does not exist. Raises
    :class:`HomeAssistantError`, chained from the original exception, when a
    ``.zst`` file cannot be read, decompressed or parsed.
    """
    is_compressed = str(path).endswith(".zst")
    if not is_compressed:
        return json_util.load_json(path)

    try:
        with open(path, "rb") as compressed_file:
            payload = zstd.decompress(compressed_file.read())
    except FileNotFoundError:
        # Must be handled before the generic OSError clause below, because
        # FileNotFoundError is an OSError subclass.
        _LOGGER.debug("JSON file not found: %s", path)
        return {}
    except zstd.ZstdError as err:
        _LOGGER.exception("Could not decompress storage file: %s", path)
        raise HomeAssistantError(f"Error decompressing {path}: {err}") from err
    except OSError as err:
        _LOGGER.exception("Storage file reading failed: %s", path)
        raise HomeAssistantError(f"Error while loading {path}: {err}") from err
    else:
        # Parsing sits outside the handlers above so that a decode failure
        # is reported as a JSON problem, not as a read/decompress problem.
        try:
            return json_util.json_loads(payload)
        except json_util.JSON_DECODE_EXCEPTIONS as err:
            _LOGGER.exception("Could not parse JSON content: %s", path)
            raise HomeAssistantError(f"Error while loading {path}: {err}") from err
async def async_migrator[_T: Mapping[str, Any] | Sequence[Any]]( async def async_migrator[_T: Mapping[str, Any] | Sequence[Any]](
hass: HomeAssistant, hass: HomeAssistant,
old_path: str, old_path: str,
@@ -214,7 +243,7 @@ class _StoreManager:
storage_file: Path = storage_path.joinpath(key) storage_file: Path = storage_path.joinpath(key)
try: try:
if storage_file.is_file(): if storage_file.is_file():
data_preload[key] = json_util.load_json(storage_file) data_preload[key] = _load_json_file(storage_file)
except Exception as ex: # noqa: BLE001 except Exception as ex: # noqa: BLE001
_LOGGER.debug("Error loading %s: %s", key, ex) _LOGGER.debug("Error loading %s: %s", key, ex)
@@ -239,6 +268,7 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
max_readable_version: int | None = None, max_readable_version: int | None = None,
minor_version: int = 1, minor_version: int = 1,
read_only: bool = False, read_only: bool = False,
compress: bool = False,
serialize_in_event_loop: bool = True, serialize_in_event_loop: bool = True,
) -> None: ) -> None:
"""Initialize storage class. """Initialize storage class.
@@ -282,10 +312,36 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
self._next_write_time = 0.0 self._next_write_time = 0.0
self._manager = get_internal_store_manager(hass) self._manager = get_internal_store_manager(hass)
self._serialize_in_event_loop = serialize_in_event_loop self._serialize_in_event_loop = serialize_in_event_loop
self._compress = compress
@cached_property
def path(self):
    """Return the config path of this store's file.

    Compressed stores use ``<key>.zst`` as the file name, all other stores
    use the bare key.
    """
    filename = (self.key + ".zst") if self._compress else self.key
    return self.hass.config.path(STORAGE_DIR, filename)
@cached_property
def _cache_key(self):
    """Return the key this store uses when talking to ``_StoreManager``.

    A compressed store is written to ``<key>.zst``, so the ``.zst`` suffix
    is appended to keep the key in line with the on-disk file name; an
    uncompressed store uses the bare key.
    NOTE(review): presumably the manager indexes its cache by file name,
    so the bare key would never match a compressed file - confirm against
    ``_StoreManager``.
    """
    return self.key + ".zst" if self._compress else self.key
@cached_property
def _uncompressed_path(self):
    """Return the config path of the plain (uncompressed) storage file.

    This is the path built from the bare key, regardless of whether the
    store is compressed. Compressed stores use it as the load fallback and
    as the location of the old plain file to delete.
    """
    config = self.hass.config
    return config.path(STORAGE_DIR, self.key)
def make_read_only(self) -> None: def make_read_only(self) -> None:
@@ -359,66 +415,19 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
# We make a copy because code might assume it's safe to mutate loaded data # We make a copy because code might assume it's safe to mutate loaded data
# and we don't want that to mess with what we're trying to store. # and we don't want that to mess with what we're trying to store.
data = deepcopy(data) data = deepcopy(data)
elif cache := self._manager.async_fetch(self.key): elif cache := self._manager.async_fetch(self._cache_key):
exists, data = cache exists, data = cache
if not exists: if not exists:
return None return None
else: else:
try: try:
data = await self.hass.async_add_executor_job( data = await self.hass.async_add_executor_job(self._load_data_from_disk)
json_util.load_json, self.path
)
except HomeAssistantError as err: except HomeAssistantError as err:
if isinstance(err.__cause__, JSONDecodeError): if isinstance(err.__cause__, (JSONDecodeError, zstd.ZstdError)):
# If we have a JSONDecodeError, it means the file is corrupt. # If the file is corrupt we log an error, rename it, and
# We can't recover from this, so we'll log an error, rename the file and # return None so startup can continue from a clean slate.
# return None so that we can start with a clean slate which will # The caller can restore from a backup.
# allow startup to continue so they can restore from a backup. await self._async_handle_corrupt_file(err)
isotime = dt_util.utcnow().isoformat()
corrupt_postfix = f".corrupt.{isotime}"
corrupt_path = f"{self.path}{corrupt_postfix}"
await self.hass.async_add_executor_job(
os.rename, self.path, corrupt_path
)
storage_key = self.key
_LOGGER.error(
"Unrecoverable error decoding storage %s at %s; "
"This may indicate an unclean shutdown, invalid syntax "
"from manual edits, or disk corruption; "
"The corrupt file has been saved as %s; "
"It is recommended to restore from backup: %s",
storage_key,
self.path,
corrupt_path,
err,
)
from .issue_registry import ( # noqa: PLC0415
IssueSeverity,
async_create_issue,
)
issue_domain = HOMEASSISTANT_DOMAIN
if (
domain := (storage_key.partition(".")[0])
) and domain in self.hass.config.components:
issue_domain = domain
async_create_issue(
self.hass,
HOMEASSISTANT_DOMAIN,
f"storage_corruption_{storage_key}_{isotime}",
is_fixable=True,
issue_domain=issue_domain,
translation_key="storage_corruption",
is_persistent=True,
severity=IssueSeverity.CRITICAL,
translation_placeholders={
"storage_key": storage_key,
"original_path": self.path,
"corrupt_path": corrupt_path,
"error": str(err),
},
)
return None return None
raise raise
@@ -570,7 +579,7 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
async def _async_handle_write_data(self, *_args): async def _async_handle_write_data(self, *_args):
"""Handle writing the config.""" """Handle writing the config."""
async with self._write_lock: async with self._write_lock:
self._manager.async_invalidate(self.key) self._manager.async_invalidate(self._cache_key)
self._async_cleanup_delay_listener() self._async_cleanup_delay_listener()
self._async_cleanup_final_write_listener() self._async_cleanup_final_write_listener()
@@ -607,6 +616,22 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
mode, json_data = json_helper.prepare_save_json(data, encoder=self._encoder) mode, json_data = json_helper.prepare_save_json(data, encoder=self._encoder)
self._write_prepared_data(mode, json_data) self._write_prepared_data(mode, json_data)
def _load_data_from_disk(self) -> json_util.JsonValueType:
    """Read this store's data from disk.

    The file at ``self.path`` is loaded first. When the store is compressed
    and that load yields ``{}`` (the value ``_load_json_file`` returns for a
    missing ``.zst`` file), the plain file at ``self._uncompressed_path`` is
    loaded instead, so an uncompressed file left on disk is still picked up.

    Returns whatever the last attempted load returned.
    """
    loaded = _load_json_file(self.path)
    if not self._compress or loaded != {}:
        return loaded
    # Nothing came back from the .zst file: fall back to the plain file.
    return _load_json_file(self._uncompressed_path)
def _write_prepared_data(self, mode: str, json_data: str | bytes) -> None: def _write_prepared_data(self, mode: str, json_data: str | bytes) -> None:
"""Write the data.""" """Write the data."""
path = self.path path = self.path
@@ -616,7 +641,62 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
write_method = ( write_method = (
write_utf8_file_atomic if self._atomic_writes else write_utf8_file write_utf8_file_atomic if self._atomic_writes else write_utf8_file
) )
write_method(path, json_data, self._private, mode=mode) if self._compress:
# Ensure we have bytes before compressing.
if isinstance(json_data, str):
json_data = json_data.encode("utf-8")
compressed = zstd.compress(json_data)
write_method(path, compressed, self._private, mode="wb")
# Remove the old uncompressed file (migration from plain → compressed).
uncompressed = self._uncompressed_path
if os.path.isfile(uncompressed):
os.unlink(uncompressed)
else:
write_method(path, json_data, self._private, mode=mode)
async def _async_handle_corrupt_file(self, err: Exception) -> None:
    """Rename a corrupt storage file and create a repair issue.

    The file that failed to load is moved aside with a
    ``.corrupt.<iso timestamp>`` suffix, an error is logged and a critical,
    persistent ``storage_corruption`` repair issue is created.

    For a compressed store the data may have been read from the plain
    fallback file because no ``.zst`` file exists. In that case the plain
    file is the corrupt one, so it is the file that gets renamed and
    reported; renaming ``self.path`` unconditionally would raise
    ``FileNotFoundError`` and name the wrong file in the log and the issue.

    ``err`` is the load error; it is included in the log message and in the
    issue's translation placeholders.
    """
    from .issue_registry import IssueSeverity, async_create_issue  # noqa: PLC0415

    isotime = dt_util.utcnow().isoformat()
    corrupt_postfix = f".corrupt.{isotime}"

    def _rename_corrupt_file() -> tuple[str, str]:
        """Move the corrupt file aside; runs in the executor (blocking I/O)."""
        original_path = self.path
        # Mirror the load fallback: without a .zst file on disk, the data
        # was read from the plain file, so that is the one to move aside.
        if self._compress and not os.path.isfile(original_path):
            original_path = self._uncompressed_path
        corrupt_path = f"{original_path}{corrupt_postfix}"
        os.rename(original_path, corrupt_path)
        return original_path, corrupt_path

    original_path, corrupt_path = await self.hass.async_add_executor_job(
        _rename_corrupt_file
    )

    storage_key = self.key
    _LOGGER.error(
        "Unrecoverable error decoding storage %s at %s; "
        "This may indicate an unclean shutdown, invalid syntax "
        "from manual edits, or disk corruption; "
        "The corrupt file has been saved as %s; "
        "It is recommended to restore from backup: %s",
        storage_key,
        original_path,
        corrupt_path,
        err,
    )

    # Attribute the issue to the integration named by the key's prefix
    # (the part before the first "."), but only if that integration is
    # loaded; otherwise fall back to the core domain.
    issue_domain = HOMEASSISTANT_DOMAIN
    if (
        domain := (storage_key.partition(".")[0])
    ) and domain in self.hass.config.components:
        issue_domain = domain

    async_create_issue(
        self.hass,
        HOMEASSISTANT_DOMAIN,
        f"storage_corruption_{storage_key}_{isotime}",
        is_fixable=True,
        issue_domain=issue_domain,
        translation_key="storage_corruption",
        is_persistent=True,
        severity=IssueSeverity.CRITICAL,
        translation_placeholders={
            "storage_key": storage_key,
            "original_path": original_path,
            "corrupt_path": corrupt_path,
            "error": str(err),
        },
    )
async def _async_migrate_func(self, old_major_version, old_minor_version, old_data): async def _async_migrate_func(self, old_major_version, old_minor_version, old_data):
"""Migrate to the new version.""" """Migrate to the new version."""
@@ -624,9 +704,15 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
async def async_remove(self) -> None: async def async_remove(self) -> None:
"""Remove all data.""" """Remove all data."""
self._manager.async_invalidate(self.key) self._manager.async_invalidate(self._cache_key)
self._async_cleanup_delay_listener() self._async_cleanup_delay_listener()
self._async_cleanup_final_write_listener() self._async_cleanup_final_write_listener()
with suppress(FileNotFoundError): def _remove_files() -> None:
await self.hass.async_add_executor_job(os.unlink, self.path) with suppress(FileNotFoundError):
os.unlink(self.path)
if self._compress:
with suppress(FileNotFoundError):
os.unlink(self._uncompressed_path)
await self.hass.async_add_executor_job(_remove_files)
+128
View File
@@ -1,6 +1,7 @@
"""Tests for the storage helper.""" """Tests for the storage helper."""
import asyncio import asyncio
from compression import zstd
from datetime import timedelta from datetime import timedelta
import json import json
import os import os
@@ -1382,3 +1383,130 @@ async def test_load_empty_returns_none_and_read_only(
await store.async_save({"new": "data"}) await store.async_save({"new": "data"})
assert hass_storage[MOCK_KEY]["data"] == MOCK_DATA assert hass_storage[MOCK_KEY]["data"] == MOCK_DATA
assert hass_storage[MOCK_KEY]["version"] == 99 assert hass_storage[MOCK_KEY]["version"] == 99
async def test_compress_save_load_round_trip(tmpdir: py.path.local) -> None:
    """Test that a compressed store writes a .zst file and reads it back."""
    config_dir = await asyncio.get_running_loop().run_in_executor(
        None, tmpdir.mkdir, "temp_storage"
    )

    async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
        store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
        await store.async_save(MOCK_DATA)

        storage_dir = Path(config_dir.strpath, ".storage")
        compressed_file = storage_dir / (MOCK_KEY + ".zst")
        uncompressed_file = storage_dir / MOCK_KEY

        # Only the compressed file may exist after the save.
        assert compressed_file.is_file()
        assert not uncompressed_file.exists()

        # The file content must be zstd-compressed JSON wrapping the data.
        stored = json.loads(zstd.decompress(compressed_file.read_bytes()))
        assert stored["data"] == MOCK_DATA

        assert await store.async_load() == MOCK_DATA

        await hass.async_stop(force=True)
async def test_compress_migrates_plain_to_compressed(tmpdir: py.path.local) -> None:
    """Test that a compressed save replaces an existing plain file."""
    config_dir = await asyncio.get_running_loop().run_in_executor(
        None, tmpdir.mkdir, "temp_storage"
    )

    async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
        # Start from an uncompressed store file on disk.
        uncompressed_store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
        await uncompressed_store.async_save(MOCK_DATA)

        storage_dir = Path(config_dir.strpath, ".storage")
        uncompressed_file = storage_dir / MOCK_KEY
        assert uncompressed_file.is_file()

        store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)

        # Until the first compressed write, the plain file is the fallback.
        assert await store.async_load() == MOCK_DATA

        # A compressed save writes the .zst file and drops the plain file.
        await store.async_save(MOCK_DATA2)
        compressed_file = storage_dir / (MOCK_KEY + ".zst")
        assert compressed_file.is_file()
        assert not uncompressed_file.exists()

        assert await store.async_load() == MOCK_DATA2

        await hass.async_stop(force=True)
async def test_compress_corrupt_file(
    tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
) -> None:
    """Test that loading a corrupt .zst file returns None and moves it aside."""
    config_dir = await asyncio.get_running_loop().run_in_executor(
        None, tmpdir.mkdir, "temp_storage"
    )

    async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
        store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
        await store.async_save(MOCK_DATA)

        storage_dir = Path(config_dir.strpath, ".storage")
        compressed_file = storage_dir / (MOCK_KEY + ".zst")

        # Overwrite the valid file with bytes that are not a zstd frame.
        await hass.async_add_executor_job(
            compressed_file.write_bytes, b"this is not valid zstd data"
        )

        assert await store.async_load() is None
        assert "Unrecoverable error decoding storage" in caplog.text

        # Exactly one file must have been renamed with a ".corrupt" marker.
        dir_entries = await hass.async_add_executor_job(os.listdir, storage_dir)
        renamed = [entry for entry in dir_entries if ".corrupt" in entry]
        assert len(renamed) == 1

        await hass.async_stop(force=True)
async def test_compress_store_manager_cache(tmpdir: py.path.local) -> None:
    """Test that a compressed store file can be preloaded by the store manager."""

    def _create_compressed_storage() -> py.path.local:
        """Create a config dir holding one zstd-compressed store file."""
        config_dir = tmpdir.mkdir("temp_config")
        storage_dir = config_dir.mkdir(".storage")
        store_content = {
            "version": MOCK_VERSION,
            "minor_version": 1,
            "key": MOCK_KEY,
            "data": MOCK_DATA,
        }
        compressed = zstd.compress(json.dumps(store_content).encode())
        storage_dir.join(MOCK_KEY + ".zst").write_binary(compressed)
        return config_dir

    config_dir = await asyncio.get_running_loop().run_in_executor(
        None, _create_compressed_storage
    )

    async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
        manager = storage.get_internal_store_manager(hass)
        await manager.async_initialize()
        await manager.async_preload([MOCK_KEY + ".zst"])

        # A compressed store is cached under key + ".zst".
        cache_entry = manager.async_fetch(MOCK_KEY + ".zst")
        assert cache_entry is not None
        exists, preloaded = cache_entry
        assert exists is True
        assert preloaded["data"] == MOCK_DATA  # type: ignore[index]

        store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
        assert await store.async_load() == MOCK_DATA

        await hass.async_stop(force=True)