diff --git a/homeassistant/components/downloader/__init__.py b/homeassistant/components/downloader/__init__.py index eb844ad8d3f..8b33c1d7ed3 100644 --- a/homeassistant/components/downloader/__init__.py +++ b/homeassistant/components/downloader/__init__.py @@ -18,6 +18,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # If path is relative, we assume relative to Home Assistant config dir if not os.path.isabs(download_path): download_path = hass.config.path(download_path) + hass.config_entries.async_update_entry( + entry, data={**entry.data, CONF_DOWNLOAD_DIR: download_path} + ) if not await hass.async_add_executor_job(os.path.isdir, download_path): _LOGGER.error( diff --git a/homeassistant/components/downloader/services.py b/homeassistant/components/downloader/services.py index bb1b968dd99..0ccaee232d7 100644 --- a/homeassistant/components/downloader/services.py +++ b/homeassistant/components/downloader/services.py @@ -11,6 +11,7 @@ import requests import voluptuous as vol from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.exceptions import ServiceValidationError from homeassistant.helpers import config_validation as cv from homeassistant.helpers.service import async_register_admin_service from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path @@ -34,24 +35,33 @@ def download_file(service: ServiceCall) -> None: entry = service.hass.config_entries.async_loaded_entries(DOMAIN)[0] download_path = entry.data[CONF_DOWNLOAD_DIR] + url: str = service.data[ATTR_URL] + subdir: str | None = service.data.get(ATTR_SUBDIR) + target_filename: str | None = service.data.get(ATTR_FILENAME) + overwrite: bool = service.data[ATTR_OVERWRITE] + + if subdir: + # Check the path + try: + raise_if_invalid_path(subdir) + except ValueError as err: + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="subdir_invalid", + translation_placeholders={"subdir": subdir}, + ) from err + if os.path.isabs(subdir): + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="subdir_not_relative", + translation_placeholders={"subdir": subdir}, + ) def do_download() -> None: """Download the file.""" + final_path = None + filename = target_filename try: - url = service.data[ATTR_URL] - - subdir = service.data.get(ATTR_SUBDIR) - - filename = service.data.get(ATTR_FILENAME) - - overwrite = service.data.get(ATTR_OVERWRITE) - - if subdir: - # Check the path - raise_if_invalid_path(subdir) - - final_path = None - req = requests.get(url, stream=True, timeout=10) if req.status_code != HTTPStatus.OK: diff --git a/homeassistant/components/downloader/strings.json b/homeassistant/components/downloader/strings.json index 7db7ea459d7..98c4a0a6c82 100644 --- a/homeassistant/components/downloader/strings.json +++ b/homeassistant/components/downloader/strings.json @@ -12,6 +12,14 @@ "single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]" } }, + "exceptions": { + "subdir_invalid": { + "message": "Invalid subdirectory, got: {subdir}" + }, + "subdir_not_relative": { + "message": "Subdirectory must be relative, got: {subdir}" + } + }, "services": { "download_file": { "name": "Download file", diff --git a/tests/components/downloader/conftest.py b/tests/components/downloader/conftest.py new file mode 100644 index 00000000000..3bb63455ccc --- /dev/null +++ b/tests/components/downloader/conftest.py @@ -0,0 +1,94 @@ +"""Provide common fixtures for downloader tests.""" + +import asyncio +from pathlib import Path + +import pytest +from requests_mock import Mocker + +from homeassistant.components.downloader.const import ( + CONF_DOWNLOAD_DIR, + DOMAIN, + DOWNLOAD_COMPLETED_EVENT, + DOWNLOAD_FAILED_EVENT, +) +from homeassistant.core import Event, HomeAssistant, callback + +from tests.common import MockConfigEntry + + +@pytest.fixture +async def setup_integration( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, +) -> MockConfigEntry: + """Set up the downloader integration for testing.""" + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + return mock_config_entry + + +@pytest.fixture +def mock_config_entry( + hass: HomeAssistant, + download_dir: Path, +) -> MockConfigEntry: + """Return a mocked config entry.""" + config_entry = MockConfigEntry( + domain=DOMAIN, + data={CONF_DOWNLOAD_DIR: str(download_dir)}, + ) + config_entry.add_to_hass(hass) + return config_entry + + +@pytest.fixture +def download_dir(tmp_path: Path) -> Path: + """Return a download directory.""" + return tmp_path + + +@pytest.fixture(autouse=True) +def mock_download_request( + requests_mock: Mocker, + download_url: str, +) -> None: + """Mock the download request.""" + requests_mock.get(download_url, text="{'one': 1}") + + +@pytest.fixture +def download_url() -> str: + """Return a mock download URL.""" + return "http://example.com/file.txt" + + +@pytest.fixture +def download_completed(hass: HomeAssistant) -> asyncio.Event: + """Return an asyncio event to wait for download completion.""" + download_event = asyncio.Event() + + @callback + def download_set(event: Event[dict[str, str]]) -> None: + """Set the event when download is completed.""" + download_event.set() + + hass.bus.async_listen_once(f"{DOMAIN}_{DOWNLOAD_COMPLETED_EVENT}", download_set) + + return download_event + + +@pytest.fixture +def download_failed(hass: HomeAssistant) -> asyncio.Event: + """Return an asyncio event to wait for download failure.""" + download_event = asyncio.Event() + + @callback + def download_set(event: Event[dict[str, str]]) -> None: + """Set the event when download has failed.""" + download_event.set() + + hass.bus.async_listen_once(f"{DOMAIN}_{DOWNLOAD_FAILED_EVENT}", download_set) + + return download_event diff --git a/tests/components/downloader/test_init.py b/tests/components/downloader/test_init.py index e74eb376b39..fe001838afe 100644 --- a/tests/components/downloader/test_init.py +++ b/tests/components/downloader/test_init.py @@ -1,6 +1,8 @@ """Tests for the downloader component init.""" -from unittest.mock import patch +from pathlib import Path + +import pytest from homeassistant.components.downloader.const import ( CONF_DOWNLOAD_DIR, @@ -13,17 +15,57 @@ from homeassistant.core import HomeAssistant from tests.common import MockConfigEntry -async def test_initialization(hass: HomeAssistant) -> None: - """Test the initialization of the downloader component.""" - config_entry = MockConfigEntry( - domain=DOMAIN, - data={ - CONF_DOWNLOAD_DIR: "/test_dir", - }, - ) - config_entry.add_to_hass(hass) - with patch("os.path.isdir", return_value=True): - assert await hass.config_entries.async_setup(config_entry.entry_id) +@pytest.fixture +def download_dir(tmp_path: Path, request: pytest.FixtureRequest) -> Path: + """Return a download directory.""" + if hasattr(request, "param"): + return tmp_path / request.param + return tmp_path + + +async def test_config_entry_setup( + hass: HomeAssistant, setup_integration: MockConfigEntry +) -> None: + """Test config entry setup.""" + config_entry = setup_integration assert hass.services.has_service(DOMAIN, SERVICE_DOWNLOAD_FILE) assert config_entry.state is ConfigEntryState.LOADED + + +async def test_config_entry_setup_relative_directory( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test config entry setup with a relative download directory.""" + relative_directory = "downloads" + hass.config_entries.async_update_entry( + mock_config_entry, + data={**mock_config_entry.data, CONF_DOWNLOAD_DIR: relative_directory}, + ) + + await hass.config_entries.async_setup(mock_config_entry.entry_id) + + # The config entry will fail to set up since the directory does not exist. + # This is not relevant for this test. + assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR + assert mock_config_entry.data[CONF_DOWNLOAD_DIR] == hass.config.path( + relative_directory + ) + + +@pytest.mark.parametrize( + "download_dir", + [ + "not_existing_path", + ], + indirect=True, +) +async def test_config_entry_setup_not_existing_directory( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, +) -> None: + """Test config entry setup without existing download directory.""" + await hass.config_entries.async_setup(mock_config_entry.entry_id) + + assert not hass.services.has_service(DOMAIN, SERVICE_DOWNLOAD_FILE) + assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR diff --git a/tests/components/downloader/test_services.py b/tests/components/downloader/test_services.py new file mode 100644 index 00000000000..fbdc088021a --- /dev/null +++ b/tests/components/downloader/test_services.py @@ -0,0 +1,54 @@ +"""Test downloader services.""" + +import asyncio +from contextlib import AbstractContextManager, nullcontext as does_not_raise + +import pytest + +from homeassistant.components.downloader.const import DOMAIN +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ServiceValidationError + + +@pytest.mark.usefixtures("setup_integration") +@pytest.mark.parametrize( + ("subdir", "expected_result"), + [ + ("test", does_not_raise()), + ("test/path", does_not_raise()), + ("~test/path", pytest.raises(ServiceValidationError)), + ("~/../test/path", pytest.raises(ServiceValidationError)), + ("../test/path", pytest.raises(ServiceValidationError)), + (".../test/path", pytest.raises(ServiceValidationError)), + ("/test/path", pytest.raises(ServiceValidationError)), + ], +) +async def test_download_invalid_subdir( + hass: HomeAssistant, + download_completed: asyncio.Event, + download_failed: asyncio.Event, + download_url: str, + subdir: str, + expected_result: AbstractContextManager, +) -> None: + """Test service invalid subdirectory.""" + + async def call_service() -> None: + """Call the download service.""" + completed = hass.async_create_task(download_completed.wait()) + failed = hass.async_create_task(download_failed.wait()) + await hass.services.async_call( + DOMAIN, + "download_file", + { + "url": download_url, + "subdir": subdir, + "filename": "file.txt", + "overwrite": True, + }, + blocking=True, + ) + await asyncio.wait((completed, failed), return_when=asyncio.FIRST_COMPLETED) + + with expected_result: + await call_service()