mirror of https://github.com/home-assistant/core.git
synced 2026-05-07 00:56:50 +02:00

Compare commits: 1 commit, a385700cc4
@@ -4,6 +4,7 @@ from __future__ import annotations
 import asyncio
 from collections.abc import Callable, Iterable, Mapping, Sequence
+from compression import zstd
 from contextlib import suppress
 from copy import deepcopy
 import inspect
@@ -48,6 +49,34 @@ STORAGE_MANAGER: HassKey[_StoreManager] = HassKey("storage_manager")
 MANAGER_CLEANUP_DELAY = 60


+def _load_json_file(path: str | Path) -> json_util.JsonValueType:
+    """Load JSON from a file, transparently decompressing .zst files.
+
+    Returns ``{}`` (the same sentinel as :func:`json_util.load_json`) when
+    the file does not exist. Raises :class:`HomeAssistantError` wrapping the
+    original exception when the file is corrupt or cannot be read.
+    """
+    if not str(path).endswith(".zst"):
+        return json_util.load_json(path)
+    try:
+        with open(path, "rb") as fh:
+            raw = zstd.decompress(fh.read())
+    except FileNotFoundError:
+        _LOGGER.debug("JSON file not found: %s", path)
+        return {}
+    except zstd.ZstdError as err:
+        _LOGGER.exception("Could not decompress storage file: %s", path)
+        raise HomeAssistantError(f"Error decompressing {path}: {err}") from err
+    except OSError as err:
+        _LOGGER.exception("Storage file reading failed: %s", path)
+        raise HomeAssistantError(f"Error while loading {path}: {err}") from err
+    try:
+        return json_util.json_loads(raw)
+    except json_util.JSON_DECODE_EXCEPTIONS as err:
+        _LOGGER.exception("Could not parse JSON content: %s", path)
+        raise HomeAssistantError(f"Error while loading {path}: {err}") from err
+
+
 async def async_migrator[_T: Mapping[str, Any] | Sequence[Any]](
     hass: HomeAssistant,
     old_path: str,
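Aside (not part of the diff): compression.zstd is the Zstandard module added to the standard library in Python 3.14 (PEP 784); the one-shot compress/decompress functions and ZstdError used above are its API. A minimal standalone sketch of the read path _load_json_file implements, with made-up file names and no Home Assistant imports:

# Standalone sketch of the .zst-aware read path (assumes Python 3.14+).
import json
from compression import zstd
from pathlib import Path

def load_maybe_compressed(path: Path) -> dict | list:
    if path.suffix != ".zst":
        return json.loads(path.read_text("utf-8"))
    try:
        raw = zstd.decompress(path.read_bytes())
    except FileNotFoundError:
        return {}  # same "missing file" sentinel json_util.load_json uses
    return json.loads(raw)

demo = Path("demo.json.zst")
demo.write_bytes(zstd.compress(json.dumps({"hello": "world"}).encode()))
print(load_maybe_compressed(demo))  # {'hello': 'world'}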
@@ -214,7 +243,7 @@ class _StoreManager:
             storage_file: Path = storage_path.joinpath(key)
             try:
                 if storage_file.is_file():
-                    data_preload[key] = json_util.load_json(storage_file)
+                    data_preload[key] = _load_json_file(storage_file)
             except Exception as ex:  # noqa: BLE001
                 _LOGGER.debug("Error loading %s: %s", key, ex)
@@ -239,6 +268,7 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
         max_readable_version: int | None = None,
         minor_version: int = 1,
         read_only: bool = False,
+        compress: bool = False,
         serialize_in_event_loop: bool = True,
     ) -> None:
         """Initialize storage class.
@@ -282,10 +312,36 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
         self._next_write_time = 0.0
         self._manager = get_internal_store_manager(hass)
         self._serialize_in_event_loop = serialize_in_event_loop
+        self._compress = compress

     @cached_property
     def path(self):
         """Return the config path."""
+        if self._compress:
+            return self.hass.config.path(STORAGE_DIR, self.key + ".zst")
         return self.hass.config.path(STORAGE_DIR, self.key)

+    @cached_property
+    def _cache_key(self):
+        """Return the cache key used with _StoreManager.
+
+        For compressed stores the file on disk is named ``key.zst``, which is
+        what ``_StoreManager._files`` contains. Using the bare ``key`` would
+        always produce a "file does not exist" cache hit, so we append the
+        suffix here to match the real filename.
+        """
+        if self._compress:
+            return self.key + ".zst"
+        return self.key
+
+    @cached_property
+    def _uncompressed_path(self):
+        """Return the plain (uncompressed) config path.
+
+        Used as a fallback when compress=True but only an uncompressed file
+        exists (e.g. the user manually extracted / edited the file), and to
+        clean up the old file after a successful compressed write.
+        """
+        return self.hass.config.path(STORAGE_DIR, self.key)
+
     def make_read_only(self) -> None:
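The _cache_key docstring is the crux of the caching change: _StoreManager indexes its caches by on-disk filename, so the lookup key has to carry the .zst suffix. A tiny illustrative function (not from the diff) showing the mapping the two properties encode:

# Illustrative mapping from storage key to file path and cache key.
STORAGE_DIR = ".storage"  # same constant name the helper uses

def paths_for(key: str, compress: bool) -> tuple[str, str]:
    # The cache key must equal the real filename; a bare key would never
    # match _StoreManager._files for a compressed store.
    filename = f"{key}.zst" if compress else key
    return f"{STORAGE_DIR}/{filename}", filename

print(paths_for("core.config", compress=True))   # ('.storage/core.config.zst', 'core.config.zst')
print(paths_for("core.config", compress=False))  # ('.storage/core.config', 'core.config')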
@@ -359,66 +415,19 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
             # We make a copy because code might assume it's safe to mutate loaded data
             # and we don't want that to mess with what we're trying to store.
             data = deepcopy(data)
-        elif cache := self._manager.async_fetch(self.key):
+        elif cache := self._manager.async_fetch(self._cache_key):
             exists, data = cache
             if not exists:
                 return None
         else:
             try:
-                data = await self.hass.async_add_executor_job(
-                    json_util.load_json, self.path
-                )
+                data = await self.hass.async_add_executor_job(self._load_data_from_disk)
             except HomeAssistantError as err:
-                if isinstance(err.__cause__, JSONDecodeError):
-                    # If we have a JSONDecodeError, it means the file is corrupt.
-                    # We can't recover from this, so we'll log an error, rename the file and
-                    # return None so that we can start with a clean slate which will
-                    # allow startup to continue so they can restore from a backup.
-                    isotime = dt_util.utcnow().isoformat()
-                    corrupt_postfix = f".corrupt.{isotime}"
-                    corrupt_path = f"{self.path}{corrupt_postfix}"
-                    await self.hass.async_add_executor_job(
-                        os.rename, self.path, corrupt_path
-                    )
-                    storage_key = self.key
-                    _LOGGER.error(
-                        "Unrecoverable error decoding storage %s at %s; "
-                        "This may indicate an unclean shutdown, invalid syntax "
-                        "from manual edits, or disk corruption; "
-                        "The corrupt file has been saved as %s; "
-                        "It is recommended to restore from backup: %s",
-                        storage_key,
-                        self.path,
-                        corrupt_path,
-                        err,
-                    )
-                    from .issue_registry import (  # noqa: PLC0415
-                        IssueSeverity,
-                        async_create_issue,
-                    )
-
-                    issue_domain = HOMEASSISTANT_DOMAIN
-                    if (
-                        domain := (storage_key.partition(".")[0])
-                    ) and domain in self.hass.config.components:
-                        issue_domain = domain
-
-                    async_create_issue(
-                        self.hass,
-                        HOMEASSISTANT_DOMAIN,
-                        f"storage_corruption_{storage_key}_{isotime}",
-                        is_fixable=True,
-                        issue_domain=issue_domain,
-                        translation_key="storage_corruption",
-                        is_persistent=True,
-                        severity=IssueSeverity.CRITICAL,
-                        translation_placeholders={
-                            "storage_key": storage_key,
-                            "original_path": self.path,
-                            "corrupt_path": corrupt_path,
-                            "error": str(err),
-                        },
-                    )
+                if isinstance(err.__cause__, (JSONDecodeError, zstd.ZstdError)):
+                    # If the file is corrupt we log an error, rename it, and
+                    # return None so startup can continue from a clean slate.
+                    # The caller can restore from a backup.
+                    await self._async_handle_corrupt_file(err)
                     return None
                 raise
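The replacement branch defers to _async_handle_corrupt_file (defined later in this diff) but keeps the same contract: quarantine the file under a timestamped name, then return None so startup proceeds with an empty store. A standalone sketch of that rename pattern, using datetime in place of Home Assistant's dt_util:

# Sketch of the quarantine rename applied to corrupt storage files.
from datetime import datetime, timezone
from pathlib import Path
import os

def quarantine(path: str) -> str:
    isotime = datetime.now(timezone.utc).isoformat()
    corrupt_path = f"{path}.corrupt.{isotime}"
    os.rename(path, corrupt_path)  # keep the evidence, free the original name
    return corrupt_path

Path("demo_store").write_text("{ not valid json", encoding="utf-8")
print(quarantine("demo_store"))  # demo_store.corrupt.2026-...+00:00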
@@ -570,7 +579,7 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
     async def _async_handle_write_data(self, *_args):
         """Handle writing the config."""
         async with self._write_lock:
-            self._manager.async_invalidate(self.key)
+            self._manager.async_invalidate(self._cache_key)
             self._async_cleanup_delay_listener()
             self._async_cleanup_final_write_listener()
@@ -607,6 +616,22 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
         mode, json_data = json_helper.prepare_save_json(data, encoder=self._encoder)
         self._write_prepared_data(mode, json_data)

+    def _load_data_from_disk(self) -> json_util.JsonValueType:
+        """Load data from disk.
+
+        Called in the executor. For compressed stores the compressed path is
+        tried first; if it does not exist the plain file is used as a fallback
+        so that a user-edited uncompressed file is transparently picked up.
+
+        Returns ``{}`` (the same sentinel as :func:`json_util.load_json`) when
+        neither file exists.
+        """
+        data = _load_json_file(self.path)
+        if data == {} and self._compress:
+            # .zst not found – fall back to the plain file.
+            data = _load_json_file(self._uncompressed_path)
+        return data
+
     def _write_prepared_data(self, mode: str, json_data: str | bytes) -> None:
         """Write the data."""
         path = self.path
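One subtlety in _load_data_from_disk: {} doubles as the "missing file" sentinel, so a compressed store whose .zst file genuinely contained a bare {} would also probe the plain path. That is harmless here because Store always persists a {"version": ..., "data": ...} wrapper, never a bare {}. A toy model of the decision (the loader is a stand-in, not the real _load_json_file):

# Toy model of the compressed-first, plain-fallback read order.
fake_disk = {"core.config": {"version": 1, "data": {"a": 1}}}  # only the plain file exists

def load(name: str) -> dict:
    return fake_disk.get(name, {})  # {} doubles as "missing file"

def load_with_fallback(key: str, compress: bool) -> dict:
    data = load(f"{key}.zst") if compress else load(key)
    if data == {} and compress:
        data = load(key)  # .zst missing: pick up a user-edited plain file
    return data

print(load_with_fallback("core.config", compress=True))  # falls back to the plain file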
@@ -616,7 +641,62 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:
         write_method = (
             write_utf8_file_atomic if self._atomic_writes else write_utf8_file
         )
-        write_method(path, json_data, self._private, mode=mode)
+        if self._compress:
+            # Ensure we have bytes before compressing.
+            if isinstance(json_data, str):
+                json_data = json_data.encode("utf-8")
+            compressed = zstd.compress(json_data)
+            write_method(path, compressed, self._private, mode="wb")
+            # Remove the old uncompressed file (migration from plain → compressed).
+            uncompressed = self._uncompressed_path
+            if os.path.isfile(uncompressed):
+                os.unlink(uncompressed)
+        else:
+            write_method(path, json_data, self._private, mode=mode)
+
+    async def _async_handle_corrupt_file(self, err: Exception) -> None:
+        """Rename a corrupt storage file and create a repair issue."""
+        from .issue_registry import IssueSeverity, async_create_issue  # noqa: PLC0415
+
+        isotime = dt_util.utcnow().isoformat()
+        corrupt_postfix = f".corrupt.{isotime}"
+        corrupt_path = f"{self.path}{corrupt_postfix}"
+        await self.hass.async_add_executor_job(os.rename, self.path, corrupt_path)
+        storage_key = self.key
+        _LOGGER.error(
+            "Unrecoverable error decoding storage %s at %s; "
+            "This may indicate an unclean shutdown, invalid syntax "
+            "from manual edits, or disk corruption; "
+            "The corrupt file has been saved as %s; "
+            "It is recommended to restore from backup: %s",
+            storage_key,
+            self.path,
+            corrupt_path,
+            err,
+        )
+
+        issue_domain = HOMEASSISTANT_DOMAIN
+        if (
+            domain := (storage_key.partition(".")[0])
+        ) and domain in self.hass.config.components:
+            issue_domain = domain
+
+        async_create_issue(
+            self.hass,
+            HOMEASSISTANT_DOMAIN,
+            f"storage_corruption_{storage_key}_{isotime}",
+            is_fixable=True,
+            issue_domain=issue_domain,
+            translation_key="storage_corruption",
+            is_persistent=True,
+            severity=IssueSeverity.CRITICAL,
+            translation_placeholders={
+                "storage_key": storage_key,
+                "original_path": self.path,
+                "corrupt_path": corrupt_path,
+                "error": str(err),
+            },
+        )

     async def _async_migrate_func(self, old_major_version, old_minor_version, old_data):
         """Migrate to the new version."""
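The write path compresses before handing the bytes to the existing atomic-write helper, so a crash mid-write can never leave a truncated .zst visible. A standalone sketch of that ordering using only the standard library (the diff itself uses Home Assistant's write_utf8_file_atomic):

# Sketch: compress first, then publish atomically via temp file + os.replace.
import json
import os
import tempfile
from compression import zstd  # Python 3.14+
from pathlib import Path

def write_compressed_atomic(path: str, data: dict) -> None:
    payload = zstd.compress(json.dumps(data).encode("utf-8"))
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
    try:
        with os.fdopen(fd, "wb") as fh:
            fh.write(payload)
        os.replace(tmp, path)  # readers see the old or new file, never a partial one
    except BaseException:
        os.unlink(tmp)
        raise

write_compressed_atomic("demo.zst", {"data": {"a": 1}})
print(json.loads(zstd.decompress(Path("demo.zst").read_bytes())))  # {'data': {'a': 1}}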
@@ -624,9 +704,15 @@ class Store[_T: Mapping[str, Any] | Sequence[Any]]:

     async def async_remove(self) -> None:
         """Remove all data."""
-        self._manager.async_invalidate(self.key)
+        self._manager.async_invalidate(self._cache_key)
         self._async_cleanup_delay_listener()
         self._async_cleanup_final_write_listener()

-        with suppress(FileNotFoundError):
-            await self.hass.async_add_executor_job(os.unlink, self.path)
+        def _remove_files() -> None:
+            with suppress(FileNotFoundError):
+                os.unlink(self.path)
+            if self._compress:
+                with suppress(FileNotFoundError):
+                    os.unlink(self._uncompressed_path)
+
+        await self.hass.async_add_executor_job(_remove_files)
@@ -1,6 +1,7 @@
 """Tests for the storage helper."""

 import asyncio
+from compression import zstd
 from datetime import timedelta
 import json
 import os
@@ -1382,3 +1383,130 @@ async def test_load_empty_returns_none_and_read_only(
     await store.async_save({"new": "data"})
     assert hass_storage[MOCK_KEY]["data"] == MOCK_DATA
     assert hass_storage[MOCK_KEY]["version"] == 99
+
+
+async def test_compress_save_load_round_trip(tmpdir: py.path.local) -> None:
+    """Test that a compressed store saves a .zst file and loads back correctly."""
+    loop = asyncio.get_running_loop()
+    config_dir = await loop.run_in_executor(None, tmpdir.mkdir, "temp_storage")
+
+    async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
+        store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
+        await store.async_save(MOCK_DATA)
+
+        storage_path = Path(config_dir.strpath) / ".storage"
+        zst_file = storage_path / (MOCK_KEY + ".zst")
+        plain_file = storage_path / MOCK_KEY
+
+        assert zst_file.is_file()
+        assert not plain_file.exists()
+
+        raw = zstd.decompress(zst_file.read_bytes())
+        on_disk = json.loads(raw)
+        assert on_disk["data"] == MOCK_DATA
+
+        loaded = await store.async_load()
+        assert loaded == MOCK_DATA
+
+        await hass.async_stop(force=True)
+
+
+async def test_compress_migrates_plain_to_compressed(tmpdir: py.path.local) -> None:
+    """Test that saving with compress=True removes an existing plain file."""
+    loop = asyncio.get_running_loop()
+    config_dir = await loop.run_in_executor(None, tmpdir.mkdir, "temp_storage")
+
+    async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
+        plain_store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
+        await plain_store.async_save(MOCK_DATA)
+
+        storage_path = Path(config_dir.strpath) / ".storage"
+        plain_file = storage_path / MOCK_KEY
+        assert plain_file.is_file()
+
+        compressed_store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
+
+        # Before the first compressed write the plain file is still the fallback.
+        loaded = await compressed_store.async_load()
+        assert loaded == MOCK_DATA
+
+        # Saving with compress=True should write .zst and remove the plain file.
+        await compressed_store.async_save(MOCK_DATA2)
+
+        zst_file = storage_path / (MOCK_KEY + ".zst")
+        assert zst_file.is_file()
+        assert not plain_file.exists()
+
+        loaded = await compressed_store.async_load()
+        assert loaded == MOCK_DATA2
+
+        await hass.async_stop(force=True)
+
+
+async def test_compress_corrupt_file(
+    tmpdir: py.path.local, caplog: pytest.LogCaptureFixture
+) -> None:
+    """Test that a corrupt .zst file is handled gracefully."""
+    loop = asyncio.get_running_loop()
+    config_dir = await loop.run_in_executor(None, tmpdir.mkdir, "temp_storage")
+
+    async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
+        store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
+        await store.async_save(MOCK_DATA)
+
+        storage_path = Path(config_dir.strpath) / ".storage"
+        zst_file = storage_path / (MOCK_KEY + ".zst")
+
+        def _corrupt_file() -> None:
+            zst_file.write_bytes(b"this is not valid zstd data")
+
+        await hass.async_add_executor_job(_corrupt_file)
+
+        loaded = await store.async_load()
+        assert loaded is None
+        assert "Unrecoverable error decoding storage" in caplog.text
+
+        files = await hass.async_add_executor_job(os.listdir, storage_path)
+        corrupt_files = [f for f in files if ".corrupt" in f]
+        assert len(corrupt_files) == 1
+
+        await hass.async_stop(force=True)
+
+
+async def test_compress_store_manager_cache(tmpdir: py.path.local) -> None:
+    """Test that compressed stores are cached and served by the store manager."""
+    loop = asyncio.get_running_loop()
+
+    def _setup_mock_storage() -> py.path.local:
+        config_dir = tmpdir.mkdir("temp_config")
+        tmp_storage = config_dir.mkdir(".storage")
+        payload = json.dumps(
+            {
+                "version": MOCK_VERSION,
+                "minor_version": 1,
+                "key": MOCK_KEY,
+                "data": MOCK_DATA,
+            }
+        ).encode()
+        tmp_storage.join(MOCK_KEY + ".zst").write_binary(zstd.compress(payload))
+        return config_dir
+
+    config_dir = await loop.run_in_executor(None, _setup_mock_storage)
+
+    async with async_test_home_assistant(config_dir=config_dir.strpath) as hass:
+        store_manager = storage.get_internal_store_manager(hass)
+        await store_manager.async_initialize()
+        await store_manager.async_preload([MOCK_KEY + ".zst"])
+
+        # The cache key for a compressed store is key + ".zst".
+        result = store_manager.async_fetch(MOCK_KEY + ".zst")
+        assert result is not None
+        exists, cached_data = result
+        assert exists is True
+        assert cached_data["data"] == MOCK_DATA  # type: ignore[index]
+
+        store = storage.Store(hass, MOCK_VERSION, MOCK_KEY, compress=True)
+        loaded = await store.async_load()
+        assert loaded == MOCK_DATA
+
+        await hass.async_stop(force=True)
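The tests assert behavior but not the size win, which varies with the payload. A quick way to eyeball the ratio on a synthetic JSON blob (the numbers printed are whatever your data produces, not a claim about Home Assistant's stores):

# Rough size comparison for a synthetic JSON payload (Python 3.14+).
import json
from compression import zstd

payload = json.dumps(
    {"items": [{"id": i, "name": f"entity_{i}"} for i in range(1000)]}
).encode()
compressed = zstd.compress(payload)
print(f"plain={len(payload)} bytes, zstd={len(compressed)} bytes, "
      f"ratio={len(payload) / len(compressed):.1f}x")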