diff --git a/homeassistant/components/backup/manager.py b/homeassistant/components/backup/manager.py index 76e1c261e31..73bbfafdcf8 100644 --- a/homeassistant/components/backup/manager.py +++ b/homeassistant/components/backup/manager.py @@ -14,7 +14,7 @@ from pathlib import Path, PurePath import shutil import tarfile import time -from typing import TYPE_CHECKING, Any, Protocol, TypedDict +from typing import IO, TYPE_CHECKING, Any, Protocol, TypedDict, cast import aiohttp from securetar import SecureTarFile, atomic_contents_add @@ -31,6 +31,7 @@ from homeassistant.helpers import ( from homeassistant.helpers.json import json_bytes from homeassistant.util import dt as dt_util +from . import util as backup_util from .agent import ( BackupAgent, BackupAgentError, @@ -48,7 +49,13 @@ from .const import ( ) from .models import AgentBackup, BackupManagerError, Folder from .store import BackupStore -from .util import make_backup_dir, read_backup, validate_password +from .util import ( + AsyncIteratorReader, + make_backup_dir, + read_backup, + validate_password, + validate_password_stream, +) @dataclass(frozen=True, kw_only=True, slots=True) @@ -248,6 +255,14 @@ class BackupReaderWriterError(HomeAssistantError): class IncorrectPasswordError(BackupReaderWriterError): """Raised when the password is incorrect.""" + _message = "The password provided is incorrect." + + +class DecryptOnDowloadNotSupported(BackupManagerError): + """Raised when on-the-fly decryption is not supported.""" + + _message = "On-the-fly decryption is not supported for this backup." + class BackupManager: """Define the format that backup managers can have.""" @@ -990,6 +1005,39 @@ class BackupManager: translation_placeholders={"failed_agents": ", ".join(agent_errors)}, ) + async def async_can_decrypt_on_download( + self, + backup_id: str, + *, + agent_id: str, + password: str | None, + ) -> None: + """Check if we are able to decrypt the backup on download.""" + try: + agent = self.backup_agents[agent_id] + except KeyError as err: + raise BackupManagerError(f"Invalid agent selected: {agent_id}") from err + if not await agent.async_get_backup(backup_id): + raise BackupManagerError( + f"Backup {backup_id} not found in agent {agent_id}" + ) + reader: IO[bytes] + if agent_id in self.local_backup_agents: + local_agent = self.local_backup_agents[agent_id] + path = local_agent.get_backup_path(backup_id) + reader = await self.hass.async_add_executor_job(open, path.as_posix(), "rb") + else: + backup_stream = await agent.async_download_backup(backup_id) + reader = cast(IO[bytes], AsyncIteratorReader(self.hass, backup_stream)) + try: + validate_password_stream(reader, password) + except backup_util.IncorrectPassword as err: + raise IncorrectPasswordError from err + except backup_util.UnsuppertedSecureTarVersion as err: + raise DecryptOnDowloadNotSupported from err + except backup_util.DecryptError as err: + raise BackupManagerError(str(err)) from err + class KnownBackups: """Track known backups.""" @@ -1372,7 +1420,7 @@ class CoreBackupReaderWriter(BackupReaderWriter): validate_password, path, password ) if not password_valid: - raise IncorrectPasswordError("The password provided is incorrect.") + raise IncorrectPasswordError def _write_restore_file() -> None: """Write the restore file.""" diff --git a/homeassistant/components/backup/util.py b/homeassistant/components/backup/util.py index 930625c52ca..ae0244591d8 100644 --- a/homeassistant/components/backup/util.py +++ b/homeassistant/components/backup/util.py @@ -3,13 +3,14 @@ from __future__ import annotations import asyncio +from collections.abc import AsyncIterator from pathlib import Path from queue import SimpleQueue import tarfile -from typing import cast +from typing import IO, cast import aiohttp -from securetar import SecureTarFile +from securetar import VERSION_HEADER, SecureTarFile, SecureTarReadError from homeassistant.backup_restore import password_to_key from homeassistant.core import HomeAssistant @@ -19,6 +20,22 @@ from .const import BUF_SIZE, LOGGER from .models import AddonInfo, AgentBackup, Folder +class DecryptError(Exception): + """Error during decryption.""" + + +class UnsuppertedSecureTarVersion(DecryptError): + """Unsupported securetar version.""" + + +class IncorrectPassword(DecryptError): + """Invalid password or corrupted backup.""" + + +class BackupEmpty(DecryptError): + """No tar files found in the backup.""" + + def make_backup_dir(path: Path) -> None: """Create a backup directory if it does not exist.""" path.mkdir(exist_ok=True) @@ -106,6 +123,70 @@ def validate_password(path: Path, password: str | None) -> bool: return False +class AsyncIteratorReader: + """Wrap an AsyncIterator.""" + + def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None: + """Initialize the wrapper.""" + self._hass = hass + self._stream = stream + self._buffer: bytes | None = None + self._pos: int = 0 + + async def _next(self) -> bytes | None: + """Get the next chunk from the iterator.""" + return await anext(self._stream, None) + + def read(self, n: int = -1, /) -> bytes: + """Read data from the iterator.""" + result = bytearray() + while n < 0 or len(result) < n: + if not self._buffer: + self._buffer = asyncio.run_coroutine_threadsafe( + self._next(), self._hass.loop + ).result() + self._pos = 0 + if not self._buffer: + # The stream is exhausted + break + chunk = self._buffer[self._pos : self._pos + n] + result.extend(chunk) + n -= len(chunk) + self._pos += len(chunk) + if self._pos == len(self._buffer): + self._buffer = None + return bytes(result) + + +def validate_password_stream( + input_stream: IO[bytes], + password: str | None, +) -> None: + """Decrypt a backup.""" + with ( + tarfile.open(fileobj=input_stream, mode="r|", bufsize=BUF_SIZE) as input_tar, + ): + for obj in input_tar: + if not obj.name.endswith((".tar", ".tgz", ".tar.gz")): + continue + if obj.pax_headers.get(VERSION_HEADER) != "2.0": + raise UnsuppertedSecureTarVersion + istf = SecureTarFile( + None, # Not used + gzip=False, + key=password_to_key(password) if password is not None else None, + mode="r", + fileobj=input_tar.extractfile(obj), + ) + with istf.decrypt(obj) as decrypted: + try: + decrypted.read(1) # Read a single byte to trigger the decryption + except SecureTarReadError as err: + raise IncorrectPassword from err + return + raise BackupEmpty + + async def receive_file( hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path ) -> None: diff --git a/homeassistant/components/backup/websocket.py b/homeassistant/components/backup/websocket.py index 0139b7fdb77..1b8433e2f24 100644 --- a/homeassistant/components/backup/websocket.py +++ b/homeassistant/components/backup/websocket.py @@ -9,7 +9,11 @@ from homeassistant.core import HomeAssistant, callback from .config import ScheduleState from .const import DATA_MANAGER, LOGGER -from .manager import IncorrectPasswordError, ManagerStateEvent +from .manager import ( + DecryptOnDowloadNotSupported, + IncorrectPasswordError, + ManagerStateEvent, +) from .models import Folder @@ -24,6 +28,7 @@ def async_register_websocket_handlers(hass: HomeAssistant, with_hassio: bool) -> websocket_api.async_register_command(hass, handle_details) websocket_api.async_register_command(hass, handle_info) + websocket_api.async_register_command(hass, handle_can_decrypt_on_download) websocket_api.async_register_command(hass, handle_create) websocket_api.async_register_command(hass, handle_create_with_automatic_settings) websocket_api.async_register_command(hass, handle_delete) @@ -147,6 +152,38 @@ async def handle_restore( connection.send_result(msg["id"]) +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "backup/can_decrypt_on_download", + vol.Required("backup_id"): str, + vol.Required("agent_id"): str, + vol.Required("password"): str, + } +) +@websocket_api.async_response +async def handle_can_decrypt_on_download( + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Check if the supplied password is correct.""" + try: + await hass.data[DATA_MANAGER].async_can_decrypt_on_download( + msg["backup_id"], + agent_id=msg["agent_id"], + password=msg.get("password"), + ) + except IncorrectPasswordError: + connection.send_error(msg["id"], "password_incorrect", "Incorrect password") + except DecryptOnDowloadNotSupported: + connection.send_error( + msg["id"], "decrypt_not_supported", "Decrypt on download not supported" + ) + else: + connection.send_result(msg["id"]) + + @websocket_api.require_admin @websocket_api.websocket_command( { diff --git a/tests/components/backup/fixtures/test_backups/2bcb3113.tar b/tests/components/backup/fixtures/test_backups/2bcb3113.tar new file mode 100644 index 00000000000..8a6556634f3 Binary files /dev/null and b/tests/components/backup/fixtures/test_backups/2bcb3113.tar differ diff --git a/tests/components/backup/fixtures/test_backups/ed1608a9.tar b/tests/components/backup/fixtures/test_backups/ed1608a9.tar new file mode 100644 index 00000000000..fc928b16d1b Binary files /dev/null and b/tests/components/backup/fixtures/test_backups/ed1608a9.tar differ diff --git a/tests/components/backup/snapshots/test_websocket.ambr b/tests/components/backup/snapshots/test_websocket.ambr index 98b2f764d43..ac4e77fca41 100644 --- a/tests/components/backup/snapshots/test_websocket.ambr +++ b/tests/components/backup/snapshots/test_websocket.ambr @@ -175,6 +175,58 @@ 'type': 'result', }) # --- +# name: test_can_decrypt_on_download[backup.local-2bcb3113-hunter2] + dict({ + 'error': dict({ + 'code': 'decrypt_not_supported', + 'message': 'Decrypt on download not supported', + }), + 'id': 1, + 'success': False, + 'type': 'result', + }) +# --- +# name: test_can_decrypt_on_download[backup.local-ed1608a9-hunter2] + dict({ + 'id': 1, + 'result': None, + 'success': True, + 'type': 'result', + }) +# --- +# name: test_can_decrypt_on_download[backup.local-ed1608a9-wrong_password] + dict({ + 'error': dict({ + 'code': 'password_incorrect', + 'message': 'Incorrect password', + }), + 'id': 1, + 'success': False, + 'type': 'result', + }) +# --- +# name: test_can_decrypt_on_download[backup.local-no_such_backup-hunter2] + dict({ + 'error': dict({ + 'code': 'home_assistant_error', + 'message': 'Backup no_such_backup not found in agent backup.local', + }), + 'id': 1, + 'success': False, + 'type': 'result', + }) +# --- +# name: test_can_decrypt_on_download[no_such_agent-ed1608a9-hunter2] + dict({ + 'error': dict({ + 'code': 'home_assistant_error', + 'message': 'Invalid agent selected: no_such_agent', + }), + 'id': 1, + 'success': False, + 'type': 'result', + }) +# --- # name: test_config_info[None] dict({ 'id': 1, diff --git a/tests/components/backup/test_websocket.py b/tests/components/backup/test_websocket.py index e95481373d6..7820408f265 100644 --- a/tests/components/backup/test_websocket.py +++ b/tests/components/backup/test_websocket.py @@ -36,7 +36,7 @@ from .common import ( setup_backup_platform, ) -from tests.common import async_fire_time_changed, async_mock_service +from tests.common import async_fire_time_changed, async_mock_service, get_fixture_path from tests.typing import WebSocketGenerator BACKUP_CALL = call( @@ -2554,3 +2554,56 @@ async def test_subscribe_event( CreateBackupEvent(stage=None, state=CreateBackupState.IN_PROGRESS) ) assert await client.receive_json() == snapshot + + +@pytest.fixture +def mock_backups() -> Generator[None]: + """Fixture to setup test backups.""" + # pylint: disable-next=import-outside-toplevel + from homeassistant.components.backup import backup as core_backup + + class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent): + def __init__(self, hass: HomeAssistant) -> None: + super().__init__(hass) + self._backup_dir = get_fixture_path("test_backups", DOMAIN) + + with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent): + yield + + +@pytest.mark.parametrize( + ("agent_id", "backup_id", "password"), + [ + # Invalid agent or backup + ("no_such_agent", "ed1608a9", "hunter2"), + ("backup.local", "no_such_backup", "hunter2"), + # Legacy backup, which can't be streamed + ("backup.local", "2bcb3113", "hunter2"), + # New backup, which can be streamed, try with correct and wrong password + ("backup.local", "ed1608a9", "hunter2"), + ("backup.local", "ed1608a9", "wrong_password"), + ], +) +@pytest.mark.usefixtures("mock_backups") +async def test_can_decrypt_on_download( + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + snapshot: SnapshotAssertion, + agent_id: str, + backup_id: str, + password: str, +) -> None: + """Test can decrypt on download.""" + await setup_backup_integration(hass, with_hassio=False) + + client = await hass_ws_client(hass) + + await client.send_json_auto_id( + { + "type": "backup/can_decrypt_on_download", + "backup_id": backup_id, + "agent_id": agent_id, + "password": password, + } + ) + assert await client.receive_json() == snapshot