Allow setting password for backups (#110630)

* Allow setting password for backups

* use is_hassio from helpers

* move it

* Fix getting psw

* Fix restoring with psw

* Address review comments

* Improve docstring

* Adjust kitchen sink

* Adjust

---------

Co-authored-by: Erik <erik@montnemery.com>
This commit is contained in:
Joakim Sørensen
2024-11-14 12:53:28 +01:00
committed by GitHub
parent d1185f8754
commit 5a69488630
14 changed files with 295 additions and 41 deletions

View File

@@ -1,6 +1,7 @@
"""Home Assistant module to handle restoring backups.""" """Home Assistant module to handle restoring backups."""
from dataclasses import dataclass from dataclasses import dataclass
import hashlib
import json import json
import logging import logging
from pathlib import Path from pathlib import Path
@@ -24,6 +25,18 @@ class RestoreBackupFileContent:
"""Definition for restore backup file content.""" """Definition for restore backup file content."""
backup_file_path: Path backup_file_path: Path
password: str | None = None
def password_to_key(password: str) -> bytes:
"""Generate a AES Key from password.
Matches the implementation in supervisor.backups.utils.password_to_key.
"""
key: bytes = password.encode()
for _ in range(100):
key = hashlib.sha256(key).digest()
return key[:16]
def restore_backup_file_content(config_dir: Path) -> RestoreBackupFileContent | None: def restore_backup_file_content(config_dir: Path) -> RestoreBackupFileContent | None:
@@ -32,7 +45,8 @@ def restore_backup_file_content(config_dir: Path) -> RestoreBackupFileContent |
try: try:
instruction_content = json.loads(instruction_path.read_text(encoding="utf-8")) instruction_content = json.loads(instruction_path.read_text(encoding="utf-8"))
return RestoreBackupFileContent( return RestoreBackupFileContent(
backup_file_path=Path(instruction_content["path"]) backup_file_path=Path(instruction_content["path"]),
password=instruction_content.get("password"),
) )
except (FileNotFoundError, json.JSONDecodeError): except (FileNotFoundError, json.JSONDecodeError):
return None return None
@@ -54,7 +68,11 @@ def _clear_configuration_directory(config_dir: Path) -> None:
shutil.rmtree(entrypath) shutil.rmtree(entrypath)
def _extract_backup(config_dir: Path, backup_file_path: Path) -> None: def _extract_backup(
config_dir: Path,
backup_file_path: Path,
password: str | None = None,
) -> None:
"""Extract the backup file to the config directory.""" """Extract the backup file to the config directory."""
with ( with (
TemporaryDirectory() as tempdir, TemporaryDirectory() as tempdir,
@@ -88,22 +106,28 @@ def _extract_backup(config_dir: Path, backup_file_path: Path) -> None:
f"homeassistant.tar{'.gz' if backup_meta["compressed"] else ''}", f"homeassistant.tar{'.gz' if backup_meta["compressed"] else ''}",
), ),
gzip=backup_meta["compressed"], gzip=backup_meta["compressed"],
key=password_to_key(password) if password is not None else None,
mode="r", mode="r",
) as istf: ) as istf:
for member in istf.getmembers():
if member.name == "data":
continue
member.name = member.name.replace("data/", "")
_clear_configuration_directory(config_dir)
istf.extractall( istf.extractall(
path=config_dir, path=Path(
members=[ tempdir,
member "homeassistant",
for member in securetar.secure_path(istf) ),
if member.name != "data" members=securetar.secure_path(istf),
],
filter="fully_trusted", filter="fully_trusted",
) )
_clear_configuration_directory(config_dir)
shutil.copytree(
Path(
tempdir,
"homeassistant",
"data",
),
config_dir,
dirs_exist_ok=True,
ignore=shutil.ignore_patterns(*(KEEP_PATHS)),
)
def restore_backup(config_dir_path: str) -> bool: def restore_backup(config_dir_path: str) -> bool:
@@ -119,7 +143,11 @@ def restore_backup(config_dir_path: str) -> bool:
backup_file_path = restore_content.backup_file_path backup_file_path = restore_content.backup_file_path
_LOGGER.info("Restoring %s", backup_file_path) _LOGGER.info("Restoring %s", backup_file_path)
try: try:
_extract_backup(config_dir, backup_file_path) _extract_backup(
config_dir=config_dir,
backup_file_path=backup_file_path,
password=restore_content.password,
)
except FileNotFoundError as err: except FileNotFoundError as err:
raise ValueError(f"Backup file {backup_file_path} does not exist") from err raise ValueError(f"Backup file {backup_file_path} does not exist") from err
_LOGGER.info("Restore complete, restarting") _LOGGER.info("Restore complete, restarting")

View File

@@ -1,5 +1,8 @@
"""The Backup integration.""" """The Backup integration."""
import voluptuous as vol
from homeassistant.const import CONF_PASSWORD
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.hassio import is_hassio from homeassistant.helpers.hassio import is_hassio
@@ -20,6 +23,8 @@ __all__ = [
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
SERVICE_CREATE_SCHEMA = vol.Schema({vol.Optional(CONF_PASSWORD): str})
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Backup integration.""" """Set up the Backup integration."""
@@ -45,11 +50,17 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
folders_included=None, folders_included=None,
name=None, name=None,
on_progress=None, on_progress=None,
password=call.data.get(CONF_PASSWORD),
) )
if backup_task := backup_manager.backup_task: if backup_task := backup_manager.backup_task:
await backup_task await backup_task
hass.services.async_register(DOMAIN, "create", async_handle_create_service) hass.services.async_register(
DOMAIN,
"create",
async_handle_create_service,
schema=SERVICE_CREATE_SCHEMA,
)
async_register_http_views(hass) async_register_http_views(hass)

View File

@@ -22,7 +22,7 @@ import aiohttp
from securetar import SecureTarFile, atomic_contents_add from securetar import SecureTarFile, atomic_contents_add
from typing_extensions import TypeVar from typing_extensions import TypeVar
from homeassistant.backup_restore import RESTORE_BACKUP_FILE from homeassistant.backup_restore import RESTORE_BACKUP_FILE, password_to_key
from homeassistant.const import __version__ as HAVERSION from homeassistant.const import __version__ as HAVERSION
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -173,7 +173,13 @@ class BaseBackupManager(abc.ABC, Generic[_BackupT]):
self.loaded_platforms = True self.loaded_platforms = True
@abc.abstractmethod @abc.abstractmethod
async def async_restore_backup(self, slug: str, **kwargs: Any) -> None: async def async_restore_backup(
self,
slug: str,
*,
password: str | None = None,
**kwargs: Any,
) -> None:
"""Restore a backup.""" """Restore a backup."""
@abc.abstractmethod @abc.abstractmethod
@@ -185,6 +191,7 @@ class BaseBackupManager(abc.ABC, Generic[_BackupT]):
folders_included: list[str] | None, folders_included: list[str] | None,
name: str | None, name: str | None,
on_progress: Callable[[BackupProgress], None] | None, on_progress: Callable[[BackupProgress], None] | None,
password: str | None,
**kwargs: Any, **kwargs: Any,
) -> NewBackup: ) -> NewBackup:
"""Initiate generating a backup. """Initiate generating a backup.
@@ -252,6 +259,7 @@ class BackupManager(BaseBackupManager[Backup]):
date=backup.date, date=backup.date,
slug=backup.slug, slug=backup.slug,
name=backup.name, name=backup.name,
protected=backup.protected,
), ),
) )
for agent in self.backup_agents.values() for agent in self.backup_agents.values()
@@ -284,6 +292,7 @@ class BackupManager(BaseBackupManager[Backup]):
date=cast(str, data["date"]), date=cast(str, data["date"]),
path=backup_path, path=backup_path,
size=round(backup_path.stat().st_size / 1_048_576, 2), size=round(backup_path.stat().st_size / 1_048_576, 2),
protected=cast(bool, data.get("protected", False)),
) )
backups[backup.slug] = backup backups[backup.slug] = backup
except (OSError, TarError, json.JSONDecodeError, KeyError) as err: except (OSError, TarError, json.JSONDecodeError, KeyError) as err:
@@ -393,6 +402,7 @@ class BackupManager(BaseBackupManager[Backup]):
folders_included: list[str] | None, folders_included: list[str] | None,
name: str | None, name: str | None,
on_progress: Callable[[BackupProgress], None] | None, on_progress: Callable[[BackupProgress], None] | None,
password: str | None,
**kwargs: Any, **kwargs: Any,
) -> NewBackup: ) -> NewBackup:
"""Initiate generating a backup.""" """Initiate generating a backup."""
@@ -409,6 +419,7 @@ class BackupManager(BaseBackupManager[Backup]):
date_str=date_str, date_str=date_str,
folders_included=folders_included, folders_included=folders_included,
on_progress=on_progress, on_progress=on_progress,
password=password,
slug=slug, slug=slug,
), ),
name="backup_manager_create_backup", name="backup_manager_create_backup",
@@ -425,6 +436,7 @@ class BackupManager(BaseBackupManager[Backup]):
date_str: str, date_str: str,
folders_included: list[str] | None, folders_included: list[str] | None,
on_progress: Callable[[BackupProgress], None] | None, on_progress: Callable[[BackupProgress], None] | None,
password: str | None,
slug: str, slug: str,
) -> Backup: ) -> Backup:
"""Generate a backup.""" """Generate a backup."""
@@ -443,13 +455,16 @@ class BackupManager(BaseBackupManager[Backup]):
"version": HAVERSION, "version": HAVERSION,
}, },
"compressed": True, "compressed": True,
"protected": password is not None,
} }
tar_file_path = Path(self.backup_dir, f"{backup_data['slug']}.tar") tar_file_path = Path(self.backup_dir, f"{backup_data['slug']}.tar")
size_in_bytes = await self.hass.async_add_executor_job( size_in_bytes = await self.hass.async_add_executor_job(
self._mkdir_and_generate_backup_contents, self._mkdir_and_generate_backup_contents,
tar_file_path, tar_file_path,
backup_data, backup_data,
database_included, database_included,
password,
) )
backup = Backup( backup = Backup(
slug=slug, slug=slug,
@@ -457,6 +472,7 @@ class BackupManager(BaseBackupManager[Backup]):
date=date_str, date=date_str,
path=tar_file_path, path=tar_file_path,
size=round(size_in_bytes / 1_048_576, 2), size=round(size_in_bytes / 1_048_576, 2),
protected=password is not None,
) )
if self.loaded_backups: if self.loaded_backups:
self.backups[slug] = backup self.backups[slug] = backup
@@ -474,6 +490,7 @@ class BackupManager(BaseBackupManager[Backup]):
tar_file_path: Path, tar_file_path: Path,
backup_data: dict[str, Any], backup_data: dict[str, Any],
database_included: bool, database_included: bool,
password: str | None = None,
) -> int: ) -> int:
"""Generate backup contents and return the size.""" """Generate backup contents and return the size."""
if not self.backup_dir.exists(): if not self.backup_dir.exists():
@@ -495,7 +512,9 @@ class BackupManager(BaseBackupManager[Backup]):
tar_info.mtime = int(time.time()) tar_info.mtime = int(time.time())
outer_secure_tarfile_tarfile.addfile(tar_info, fileobj=fileobj) outer_secure_tarfile_tarfile.addfile(tar_info, fileobj=fileobj)
with outer_secure_tarfile.create_inner_tar( with outer_secure_tarfile.create_inner_tar(
"./homeassistant.tar.gz", gzip=True "./homeassistant.tar.gz",
gzip=True,
key=password_to_key(password) if password is not None else None,
) as core_tar: ) as core_tar:
atomic_contents_add( atomic_contents_add(
tar_file=core_tar, tar_file=core_tar,
@@ -503,10 +522,15 @@ class BackupManager(BaseBackupManager[Backup]):
excludes=excludes, excludes=excludes,
arcname="data", arcname="data",
) )
return tar_file_path.stat().st_size return tar_file_path.stat().st_size
async def async_restore_backup(self, slug: str, **kwargs: Any) -> None: async def async_restore_backup(
self,
slug: str,
*,
password: str | None = None,
**kwargs: Any,
) -> None:
"""Restore a backup. """Restore a backup.
This will write the restore information to .HA_RESTORE which This will write the restore information to .HA_RESTORE which
@@ -518,7 +542,7 @@ class BackupManager(BaseBackupManager[Backup]):
def _write_restore_file() -> None: def _write_restore_file() -> None:
"""Write the restore file.""" """Write the restore file."""
Path(self.hass.config.path(RESTORE_BACKUP_FILE)).write_text( Path(self.hass.config.path(RESTORE_BACKUP_FILE)).write_text(
json.dumps({"path": backup.path.as_posix()}), json.dumps({"path": backup.path.as_posix(), "password": password}),
encoding="utf-8", encoding="utf-8",
) )

View File

@@ -8,9 +8,10 @@ class BaseBackup:
"""Base backup class.""" """Base backup class."""
date: str date: str
name: str
protected: bool
slug: str slug: str
size: float size: float
name: str
def as_dict(self) -> dict: def as_dict(self) -> dict:
"""Return a dict representation of this backup.""" """Return a dict representation of this backup."""
@@ -26,3 +27,4 @@ class BackupUploadMetadata:
size: float # The size of the backup (in bytes) size: float # The size of the backup (in bytes)
name: str # The name of the backup name: str # The name of the backup
homeassistant: str # The version of Home Assistant that created the backup homeassistant: str # The version of Home Assistant that created the backup
protected: bool # If the backup is protected

View File

@@ -1 +1,7 @@
create: create:
fields:
password:
required: false
selector:
text:
type: password

View File

@@ -2,7 +2,13 @@
"services": { "services": {
"create": { "create": {
"name": "Create backup", "name": "Create backup",
"description": "Creates a new backup." "description": "Creates a new backup.",
"fields": {
"password": {
"name": "[%key:common::config_flow::data::password%]",
"description": "Password protect the backup"
}
}
} }
} }
} }

View File

@@ -98,6 +98,7 @@ async def handle_remove(
{ {
vol.Required("type"): "backup/restore", vol.Required("type"): "backup/restore",
vol.Required("slug"): str, vol.Required("slug"): str,
vol.Optional("password"): str,
} }
) )
@websocket_api.async_response @websocket_api.async_response
@@ -107,7 +108,10 @@ async def handle_restore(
msg: dict[str, Any], msg: dict[str, Any],
) -> None: ) -> None:
"""Restore a backup.""" """Restore a backup."""
await hass.data[DATA_MANAGER].async_restore_backup(msg["slug"]) await hass.data[DATA_MANAGER].async_restore_backup(
slug=msg["slug"],
password=msg.get("password"),
)
connection.send_result(msg["id"]) connection.send_result(msg["id"])
@@ -119,6 +123,7 @@ async def handle_restore(
vol.Optional("database_included", default=True): bool, vol.Optional("database_included", default=True): bool,
vol.Optional("folders_included"): [str], vol.Optional("folders_included"): [str],
vol.Optional("name"): str, vol.Optional("name"): str,
vol.Optional("password"): str,
} }
) )
@websocket_api.async_response @websocket_api.async_response
@@ -138,6 +143,7 @@ async def handle_create(
folders_included=msg.get("folders_included"), folders_included=msg.get("folders_included"),
name=msg.get("name"), name=msg.get("name"),
on_progress=on_progress, on_progress=on_progress,
password=msg.get("password"),
) )
connection.send_result(msg["id"], backup) connection.send_result(msg["id"], backup)

View File

@@ -34,6 +34,7 @@ class KitchenSinkBackupAgent(BackupAgent):
UploadedBackup( UploadedBackup(
id="def456", id="def456",
name="Kitchen sink syncer", name="Kitchen sink syncer",
protected=False,
slug="abc123", slug="abc123",
size=1234, size=1234,
date="1970-01-01T00:00:00Z", date="1970-01-01T00:00:00Z",
@@ -63,6 +64,7 @@ class KitchenSinkBackupAgent(BackupAgent):
UploadedBackup( UploadedBackup(
id=uuid4().hex, id=uuid4().hex,
name=metadata.name, name=metadata.name,
protected=metadata.protected,
slug=metadata.slug, slug=metadata.slug,
size=metadata.size, size=metadata.size,
date=metadata.date, date=metadata.date,

View File

@@ -20,6 +20,7 @@ TEST_BACKUP = Backup(
date="1970-01-01T00:00:00.000Z", date="1970-01-01T00:00:00.000Z",
path=Path("abc123.tar"), path=Path("abc123.tar"),
size=0.0, size=0.0,
protected=False,
) )
@@ -49,10 +50,11 @@ class BackupAgentTest(BackupAgent):
return [ return [
UploadedBackup( UploadedBackup(
id="abc123", id="abc123",
name="Test",
slug="abc123",
size=13.37,
date="1970-01-01T00:00:00Z", date="1970-01-01T00:00:00Z",
name="Test",
protected=False,
size=13.37,
slug="abc123",
) )
] ]

View File

@@ -76,6 +76,7 @@
'date': '1970-01-01T00:00:00Z', 'date': '1970-01-01T00:00:00Z',
'id': 'abc123', 'id': 'abc123',
'name': 'Test', 'name': 'Test',
'protected': False,
'size': 13.37, 'size': 13.37,
'slug': 'abc123', 'slug': 'abc123',
}), }),
@@ -93,6 +94,7 @@
'date': '1970-01-01T00:00:00Z', 'date': '1970-01-01T00:00:00Z',
'id': 'abc123', 'id': 'abc123',
'name': 'Test', 'name': 'Test',
'protected': False,
'size': 13.37, 'size': 13.37,
'slug': 'abc123', 'slug': 'abc123',
}), }),
@@ -353,6 +355,7 @@
'date': '1970-01-01T00:00:00.000Z', 'date': '1970-01-01T00:00:00.000Z',
'name': 'Test', 'name': 'Test',
'path': 'abc123.tar', 'path': 'abc123.tar',
'protected': False,
'size': 0.0, 'size': 0.0,
'slug': 'abc123', 'slug': 'abc123',
}), }),
@@ -371,7 +374,7 @@
'type': 'result', 'type': 'result',
}) })
# --- # ---
# name: test_generate[with_hassio] # name: test_generate[with_hassio-None]
dict({ dict({
'error': dict({ 'error': dict({
'code': 'unknown_command', 'code': 'unknown_command',
@@ -382,7 +385,29 @@
'type': 'result', 'type': 'result',
}) })
# --- # ---
# name: test_generate[without_hassio] # name: test_generate[with_hassio-data1]
dict({
'error': dict({
'code': 'unknown_command',
'message': 'Unknown command.',
}),
'id': 1,
'success': False,
'type': 'result',
})
# ---
# name: test_generate[with_hassio-data2]
dict({
'error': dict({
'code': 'unknown_command',
'message': 'Unknown command.',
}),
'id': 1,
'success': False,
'type': 'result',
})
# ---
# name: test_generate[without_hassio-None]
dict({ dict({
'id': 1, 'id': 1,
'result': dict({ 'result': dict({
@@ -392,7 +417,49 @@
'type': 'result', 'type': 'result',
}) })
# --- # ---
# name: test_generate[without_hassio].1 # name: test_generate[without_hassio-None].1
dict({
'event': dict({
'done': True,
'stage': None,
'success': True,
}),
'id': 1,
'type': 'event',
})
# ---
# name: test_generate[without_hassio-data1]
dict({
'id': 1,
'result': dict({
'slug': '27f5c632',
}),
'success': True,
'type': 'result',
})
# ---
# name: test_generate[without_hassio-data1].1
dict({
'event': dict({
'done': True,
'stage': None,
'success': True,
}),
'id': 1,
'type': 'event',
})
# ---
# name: test_generate[without_hassio-data2]
dict({
'id': 1,
'result': dict({
'slug': '27f5c632',
}),
'success': True,
'type': 'result',
})
# ---
# name: test_generate[without_hassio-data2].1
dict({ dict({
'event': dict({ 'event': dict({
'done': True, 'done': True,
@@ -444,6 +511,7 @@
'date': '1970-01-01T00:00:00.000Z', 'date': '1970-01-01T00:00:00.000Z',
'name': 'Test', 'name': 'Test',
'path': 'abc123.tar', 'path': 'abc123.tar',
'protected': False,
'size': 0.0, 'size': 0.0,
'slug': 'abc123', 'slug': 'abc123',
}), }),

View File

@@ -1,5 +1,6 @@
"""Tests for the Backup integration.""" """Tests for the Backup integration."""
from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -26,8 +27,10 @@ async def test_setup_with_hassio(
) in caplog.text ) in caplog.text
@pytest.mark.parametrize("service_data", [None, {}, {"password": "abc123"}])
async def test_create_service( async def test_create_service(
hass: HomeAssistant, hass: HomeAssistant,
service_data: dict[str, Any] | None,
) -> None: ) -> None:
"""Test generate backup.""" """Test generate backup."""
await setup_backup_integration(hass) await setup_backup_integration(hass)
@@ -39,6 +42,7 @@ async def test_create_service(
DOMAIN, DOMAIN,
"create", "create",
blocking=True, blocking=True,
service_data=service_data,
) )
assert generate_backup.called assert generate_backup.called

View File

@@ -38,6 +38,7 @@ async def _mock_backup_generation(
*, *,
database_included: bool = True, database_included: bool = True,
name: str | None = "Core 2025.1.0", name: str | None = "Core 2025.1.0",
password: str | None = None,
) -> None: ) -> None:
"""Mock backup generator.""" """Mock backup generator."""
@@ -54,6 +55,7 @@ async def _mock_backup_generation(
folders_included=[], folders_included=[],
name=name, name=name,
on_progress=on_progress, on_progress=on_progress,
password=password,
) )
assert manager.backup_task is not None assert manager.backup_task is not None
assert progress == [] assert progress == []
@@ -73,6 +75,7 @@ async def _mock_backup_generation(
"version": "2025.1.0", "version": "2025.1.0",
}, },
"name": name, "name": name,
"protected": bool(password),
"slug": ANY, "slug": ANY,
"type": "partial", "type": "partial",
} }
@@ -199,6 +202,7 @@ async def test_async_create_backup_when_backing_up(hass: HomeAssistant) -> None:
folders_included=[], folders_included=[],
name=None, name=None,
on_progress=None, on_progress=None,
password=None,
) )
event.set() event.set()
@@ -206,7 +210,12 @@ async def test_async_create_backup_when_backing_up(hass: HomeAssistant) -> None:
@pytest.mark.usefixtures("mock_backup_generation") @pytest.mark.usefixtures("mock_backup_generation")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"params", "params",
[{}, {"database_included": True, "name": "abc123"}, {"database_included": False}], [
{},
{"database_included": True, "name": "abc123"},
{"database_included": False},
{"password": "abc123"},
],
) )
async def test_async_create_backup( async def test_async_create_backup(
hass: HomeAssistant, hass: HomeAssistant,
@@ -228,6 +237,10 @@ async def test_async_create_backup(
assert "Loaded 0 platforms" in caplog.text assert "Loaded 0 platforms" in caplog.text
assert "Loaded 0 agents" in caplog.text assert "Loaded 0 agents" in caplog.text
assert len(manager.backups) == 1
backup = list(manager.backups.values())[0]
assert backup.protected is bool(params.get("password"))
async def test_loading_platforms( async def test_loading_platforms(
hass: HomeAssistant, hass: HomeAssistant,
@@ -351,6 +364,7 @@ async def test_syncing_backup(
date=backup.date, date=backup.date,
homeassistant="2025.1.0", homeassistant="2025.1.0",
name=backup.name, name=backup.name,
protected=backup.protected,
size=backup.size, size=backup.size,
slug=backup.slug, slug=backup.slug,
) )
@@ -415,6 +429,7 @@ async def test_syncing_backup_with_exception(
date=backup.date, date=backup.date,
homeassistant="2025.1.0", homeassistant="2025.1.0",
name=backup.name, name=backup.name,
protected=backup.protected,
size=backup.size, size=backup.size,
slug=backup.slug, slug=backup.slug,
) )
@@ -600,7 +615,32 @@ async def test_async_trigger_restore(
patch("homeassistant.core.ServiceRegistry.async_call") as mocked_service_call, patch("homeassistant.core.ServiceRegistry.async_call") as mocked_service_call,
): ):
await manager.async_restore_backup(TEST_BACKUP.slug) await manager.async_restore_backup(TEST_BACKUP.slug)
assert mocked_write_text.call_args[0][0] == '{"path": "abc123.tar"}' assert (
mocked_write_text.call_args[0][0]
== '{"path": "abc123.tar", "password": null}'
)
assert mocked_service_call.called
async def test_async_trigger_restore_with_password(
hass: HomeAssistant,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test trigger restore."""
manager = BackupManager(hass)
manager.loaded_backups = True
manager.backups = {TEST_BACKUP.slug: TEST_BACKUP}
with (
patch("pathlib.Path.exists", return_value=True),
patch("pathlib.Path.write_text") as mocked_write_text,
patch("homeassistant.core.ServiceRegistry.async_call") as mocked_service_call,
):
await manager.async_restore_backup(slug=TEST_BACKUP.slug, password="abc123")
assert (
mocked_write_text.call_args[0][0]
== '{"path": "abc123.tar", "password": "abc123"}'
)
assert mocked_service_call.called assert mocked_service_call.called

View File

@@ -1,6 +1,7 @@
"""Tests for the Backup integration.""" """Tests for the Backup integration."""
from pathlib import Path from pathlib import Path
from typing import Any
from unittest.mock import ANY, AsyncMock, patch from unittest.mock import ANY, AsyncMock, patch
from freezegun.api import FrozenDateTimeFactory from freezegun.api import FrozenDateTimeFactory
@@ -126,6 +127,14 @@ async def test_remove(
assert await client.receive_json() == snapshot assert await client.receive_json() == snapshot
@pytest.mark.parametrize(
"data",
[
None,
{},
{"password": "abc123"},
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("with_hassio", "number_of_messages"), ("with_hassio", "number_of_messages"),
[ [
@@ -137,6 +146,7 @@ async def test_remove(
async def test_generate( async def test_generate(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
data: dict[str, Any] | None,
freezer: FrozenDateTimeFactory, freezer: FrozenDateTimeFactory,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
with_hassio: bool, with_hassio: bool,
@@ -149,7 +159,7 @@ async def test_generate(
freezer.move_to("2024-11-13 12:01:00+01:00") freezer.move_to("2024-11-13 12:01:00+01:00")
await hass.async_block_till_done() await hass.async_block_till_done()
await client.send_json_auto_id({"type": "backup/generate"}) await client.send_json_auto_id({"type": "backup/generate", **(data or {})})
for _ in range(number_of_messages): for _ in range(number_of_messages):
assert await client.receive_json() == snapshot assert await client.receive_json() == snapshot
@@ -203,6 +213,7 @@ async def test_generate_without_hassio(
"folders_included": None, "folders_included": None,
"name": None, "name": None,
"on_progress": ANY, "on_progress": ANY,
"password": None,
} }
| expected_extra_call_params | expected_extra_call_params
) )

View File

@@ -19,7 +19,23 @@ from .common import get_test_config_dir
( (
None, None,
'{"path": "test"}', '{"path": "test"}',
backup_restore.RestoreBackupFileContent(backup_file_path=Path("test")), backup_restore.RestoreBackupFileContent(
backup_file_path=Path("test"), password=None
),
),
(
None,
'{"path": "test", "password": "psw"}',
backup_restore.RestoreBackupFileContent(
backup_file_path=Path("test"), password="psw"
),
),
(
None,
'{"path": "test", "password": null}',
backup_restore.RestoreBackupFileContent(
backup_file_path=Path("test"), password=None
),
), ),
], ],
) )
@@ -155,15 +171,17 @@ def test_removal_of_current_configuration_when_restoring() -> None:
return_value=[x["path"] for x in mock_config_dir], return_value=[x["path"] for x in mock_config_dir],
), ),
mock.patch("pathlib.Path.unlink") as unlink_mock, mock.patch("pathlib.Path.unlink") as unlink_mock,
mock.patch("shutil.rmtree") as rmtreemock, mock.patch("shutil.copytree") as copytree_mock,
mock.patch("shutil.rmtree") as rmtree_mock,
): ):
assert backup_restore.restore_backup(config_dir) is True assert backup_restore.restore_backup(config_dir) is True
assert unlink_mock.call_count == 2 assert unlink_mock.call_count == 2
assert copytree_mock.call_count == 1
assert ( assert (
rmtreemock.call_count == 1 rmtree_mock.call_count == 1
) # We have 2 directories in the config directory, but backups is kept ) # We have 2 directories in the config directory, but backups is kept
removed_directories = {Path(call.args[0]) for call in rmtreemock.mock_calls} removed_directories = {Path(call.args[0]) for call in rmtree_mock.mock_calls}
assert removed_directories == {Path(config_dir, "www")} assert removed_directories == {Path(config_dir, "www")}
@@ -177,8 +195,8 @@ def test_extracting_the_contents_of_a_backup_file() -> None:
getmembers_mock = mock.MagicMock( getmembers_mock = mock.MagicMock(
return_value=[ return_value=[
tarfile.TarInfo(name="../data/test"),
tarfile.TarInfo(name="data"), tarfile.TarInfo(name="data"),
tarfile.TarInfo(name="data/../test"),
tarfile.TarInfo(name="data/.HA_VERSION"), tarfile.TarInfo(name="data/.HA_VERSION"),
tarfile.TarInfo(name="data/.storage"), tarfile.TarInfo(name="data/.storage"),
tarfile.TarInfo(name="data/www"), tarfile.TarInfo(name="data/www"),
@@ -190,7 +208,7 @@ def test_extracting_the_contents_of_a_backup_file() -> None:
mock.patch( mock.patch(
"homeassistant.backup_restore.restore_backup_file_content", "homeassistant.backup_restore.restore_backup_file_content",
return_value=backup_restore.RestoreBackupFileContent( return_value=backup_restore.RestoreBackupFileContent(
backup_file_path=backup_file_path backup_file_path=backup_file_path,
), ),
), ),
mock.patch( mock.patch(
@@ -205,11 +223,37 @@ def test_extracting_the_contents_of_a_backup_file() -> None:
mock.patch("pathlib.Path.read_text", _patched_path_read_text), mock.patch("pathlib.Path.read_text", _patched_path_read_text),
mock.patch("pathlib.Path.is_file", return_value=False), mock.patch("pathlib.Path.is_file", return_value=False),
mock.patch("pathlib.Path.iterdir", return_value=[]), mock.patch("pathlib.Path.iterdir", return_value=[]),
mock.patch("shutil.copytree"),
): ):
assert backup_restore.restore_backup(config_dir) is True assert backup_restore.restore_backup(config_dir) is True
assert getmembers_mock.call_count == 1
assert extractall_mock.call_count == 2 assert extractall_mock.call_count == 2
assert { assert {
member.name for member in extractall_mock.mock_calls[-1].kwargs["members"] member.name for member in extractall_mock.mock_calls[-1].kwargs["members"]
} == {".HA_VERSION", ".storage", "www"} } == {"data", "data/.HA_VERSION", "data/.storage", "data/www"}
@pytest.mark.parametrize(
("password", "expected"),
[
("test", b"\xf0\x9b\xb9\x1f\xdc,\xff\xd5x\xd6\xd6\x8fz\x19.\x0f"),
("lorem ipsum...", b"#\xe0\xfc\xe0\xdb?_\x1f,$\rQ\xf4\xf5\xd8\xfb"),
],
)
def test_pw_to_key(password: str | None, expected: bytes | None) -> None:
"""Test password to key conversion."""
assert backup_restore.password_to_key(password) == expected
@pytest.mark.parametrize(
("password", "expected"),
[
(None, None),
("test", b"\xf0\x9b\xb9\x1f\xdc,\xff\xd5x\xd6\xd6\x8fz\x19.\x0f"),
("lorem ipsum...", b"#\xe0\xfc\xe0\xdb?_\x1f,$\rQ\xf4\xf5\xd8\xfb"),
],
)
def test_pw_to_key_none(password: str | None, expected: bytes | None) -> None:
"""Test password to key conversion."""
with pytest.raises(AttributeError):
backup_restore.password_to_key(None)