Improve downloader service (#150046)

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
Martin Hjelmare
2025-08-05 16:12:55 +02:00
committed by GitHub
parent 37510aa316
commit fe95f6e1c5
6 changed files with 237 additions and 26 deletions

View File

@@ -18,6 +18,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# If path is relative, we assume relative to Home Assistant config dir
if not os.path.isabs(download_path):
download_path = hass.config.path(download_path)
hass.config_entries.async_update_entry(
entry, data={**entry.data, CONF_DOWNLOAD_DIR: download_path}
)
if not await hass.async_add_executor_job(os.path.isdir, download_path):
_LOGGER.error(

View File

@@ -11,6 +11,7 @@ import requests
import voluptuous as vol
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
@@ -34,24 +35,33 @@ def download_file(service: ServiceCall) -> None:
entry = service.hass.config_entries.async_loaded_entries(DOMAIN)[0]
download_path = entry.data[CONF_DOWNLOAD_DIR]
url: str = service.data[ATTR_URL]
subdir: str | None = service.data.get(ATTR_SUBDIR)
target_filename: str | None = service.data.get(ATTR_FILENAME)
overwrite: bool = service.data[ATTR_OVERWRITE]
if subdir:
# Check the path
try:
raise_if_invalid_path(subdir)
except ValueError as err:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="subdir_invalid",
translation_placeholders={"subdir": subdir},
) from err
if os.path.isabs(subdir):
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="subdir_not_relative",
translation_placeholders={"subdir": subdir},
)
def do_download() -> None:
"""Download the file."""
final_path = None
filename = target_filename
try:
url = service.data[ATTR_URL]
subdir = service.data.get(ATTR_SUBDIR)
filename = service.data.get(ATTR_FILENAME)
overwrite = service.data.get(ATTR_OVERWRITE)
if subdir:
# Check the path
raise_if_invalid_path(subdir)
final_path = None
req = requests.get(url, stream=True, timeout=10)
if req.status_code != HTTPStatus.OK:

View File

@@ -12,6 +12,14 @@
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
}
},
"exceptions": {
"subdir_invalid": {
"message": "Invalid subdirectory, got: {subdir}"
},
"subdir_not_relative": {
"message": "Subdirectory must be relative, got: {subdir}"
}
},
"services": {
"download_file": {
"name": "Download file",

View File

@@ -0,0 +1,94 @@
"""Provide common fixtures for downloader tests."""
import asyncio
from pathlib import Path
import pytest
from requests_mock import Mocker
from homeassistant.components.downloader.const import (
CONF_DOWNLOAD_DIR,
DOMAIN,
DOWNLOAD_COMPLETED_EVENT,
DOWNLOAD_FAILED_EVENT,
)
from homeassistant.core import Event, HomeAssistant, callback
from tests.common import MockConfigEntry
@pytest.fixture
async def setup_integration(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
) -> MockConfigEntry:
"""Set up the downloader integration for testing."""
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
return mock_config_entry
@pytest.fixture
def mock_config_entry(
hass: HomeAssistant,
download_dir: Path,
) -> MockConfigEntry:
"""Return a mocked config entry."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_DOWNLOAD_DIR: str(download_dir)},
)
config_entry.add_to_hass(hass)
return config_entry
@pytest.fixture
def download_dir(tmp_path: Path) -> Path:
"""Return a download directory."""
return tmp_path
@pytest.fixture(autouse=True)
def mock_download_request(
requests_mock: Mocker,
download_url: str,
) -> None:
"""Mock the download request."""
requests_mock.get(download_url, text="{'one': 1}")
@pytest.fixture
def download_url() -> str:
"""Return a mock download URL."""
return "http://example.com/file.txt"
@pytest.fixture
def download_completed(hass: HomeAssistant) -> asyncio.Event:
"""Return an asyncio event to wait for download completion."""
download_event = asyncio.Event()
@callback
def download_set(event: Event[dict[str, str]]) -> None:
"""Set the event when download is completed."""
download_event.set()
hass.bus.async_listen_once(f"{DOMAIN}_{DOWNLOAD_COMPLETED_EVENT}", download_set)
return download_event
@pytest.fixture
def download_failed(hass: HomeAssistant) -> asyncio.Event:
"""Return an asyncio event to wait for download failure."""
download_event = asyncio.Event()
@callback
def download_set(event: Event[dict[str, str]]) -> None:
"""Set the event when download has failed."""
download_event.set()
hass.bus.async_listen_once(f"{DOMAIN}_{DOWNLOAD_FAILED_EVENT}", download_set)
return download_event

View File

@@ -1,6 +1,8 @@
"""Tests for the downloader component init."""
from unittest.mock import patch
from pathlib import Path
import pytest
from homeassistant.components.downloader.const import (
CONF_DOWNLOAD_DIR,
@@ -13,17 +15,57 @@ from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry
async def test_initialization(hass: HomeAssistant) -> None:
"""Test the initialization of the downloader component."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_DOWNLOAD_DIR: "/test_dir",
},
)
config_entry.add_to_hass(hass)
with patch("os.path.isdir", return_value=True):
assert await hass.config_entries.async_setup(config_entry.entry_id)
@pytest.fixture
def download_dir(tmp_path: Path, request: pytest.FixtureRequest) -> Path:
"""Return a download directory."""
if hasattr(request, "param"):
return tmp_path / request.param
return tmp_path
async def test_config_entry_setup(
hass: HomeAssistant, setup_integration: MockConfigEntry
) -> None:
"""Test config entry setup."""
config_entry = setup_integration
assert hass.services.has_service(DOMAIN, SERVICE_DOWNLOAD_FILE)
assert config_entry.state is ConfigEntryState.LOADED
async def test_config_entry_setup_relative_directory(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test config entry setup with a relative download directory."""
relative_directory = "downloads"
hass.config_entries.async_update_entry(
mock_config_entry,
data={**mock_config_entry.data, CONF_DOWNLOAD_DIR: relative_directory},
)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
# The config entry will fail to set up since the directory does not exist.
# This is not relevant for this test.
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
assert mock_config_entry.data[CONF_DOWNLOAD_DIR] == hass.config.path(
relative_directory
)
@pytest.mark.parametrize(
"download_dir",
[
"not_existing_path",
],
indirect=True,
)
async def test_config_entry_setup_not_existing_directory(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test config entry setup without existing download directory."""
await hass.config_entries.async_setup(mock_config_entry.entry_id)
assert not hass.services.has_service(DOMAIN, SERVICE_DOWNLOAD_FILE)
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR

View File

@@ -0,0 +1,54 @@
"""Test downloader services."""
import asyncio
from contextlib import AbstractContextManager, nullcontext as does_not_raise
import pytest
from homeassistant.components.downloader.const import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ServiceValidationError
@pytest.mark.usefixtures("setup_integration")
@pytest.mark.parametrize(
("subdir", "expected_result"),
[
("test", does_not_raise()),
("test/path", does_not_raise()),
("~test/path", pytest.raises(ServiceValidationError)),
("~/../test/path", pytest.raises(ServiceValidationError)),
("../test/path", pytest.raises(ServiceValidationError)),
(".../test/path", pytest.raises(ServiceValidationError)),
("/test/path", pytest.raises(ServiceValidationError)),
],
)
async def test_download_invalid_subdir(
hass: HomeAssistant,
download_completed: asyncio.Event,
download_failed: asyncio.Event,
download_url: str,
subdir: str,
expected_result: AbstractContextManager,
) -> None:
"""Test service invalid subdirectory."""
async def call_service() -> None:
"""Call the download service."""
completed = hass.async_create_task(download_completed.wait())
failed = hass.async_create_task(download_failed.wait())
await hass.services.async_call(
DOMAIN,
"download_file",
{
"url": download_url,
"subdir": subdir,
"filename": "file.txt",
"overwrite": True,
},
blocking=True,
)
await asyncio.wait((completed, failed), return_when=asyncio.FIRST_COMPLETED)
with expected_result:
await call_service()