mirror of
https://github.com/home-assistant/core.git
synced 2025-09-06 05:11:35 +02:00
Improve downloader service (#150046)
Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
# If path is relative, we assume relative to Home Assistant config dir
|
||||
if not os.path.isabs(download_path):
|
||||
download_path = hass.config.path(download_path)
|
||||
hass.config_entries.async_update_entry(
|
||||
entry, data={**entry.data, CONF_DOWNLOAD_DIR: download_path}
|
||||
)
|
||||
|
||||
if not await hass.async_add_executor_job(os.path.isdir, download_path):
|
||||
_LOGGER.error(
|
||||
|
@@ -11,6 +11,7 @@ import requests
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, ServiceCall, callback
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.service import async_register_admin_service
|
||||
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
|
||||
@@ -34,24 +35,33 @@ def download_file(service: ServiceCall) -> None:
|
||||
|
||||
entry = service.hass.config_entries.async_loaded_entries(DOMAIN)[0]
|
||||
download_path = entry.data[CONF_DOWNLOAD_DIR]
|
||||
url: str = service.data[ATTR_URL]
|
||||
subdir: str | None = service.data.get(ATTR_SUBDIR)
|
||||
target_filename: str | None = service.data.get(ATTR_FILENAME)
|
||||
overwrite: bool = service.data[ATTR_OVERWRITE]
|
||||
|
||||
if subdir:
|
||||
# Check the path
|
||||
try:
|
||||
raise_if_invalid_path(subdir)
|
||||
except ValueError as err:
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="subdir_invalid",
|
||||
translation_placeholders={"subdir": subdir},
|
||||
) from err
|
||||
if os.path.isabs(subdir):
|
||||
raise ServiceValidationError(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="subdir_not_relative",
|
||||
translation_placeholders={"subdir": subdir},
|
||||
)
|
||||
|
||||
def do_download() -> None:
|
||||
"""Download the file."""
|
||||
final_path = None
|
||||
filename = target_filename
|
||||
try:
|
||||
url = service.data[ATTR_URL]
|
||||
|
||||
subdir = service.data.get(ATTR_SUBDIR)
|
||||
|
||||
filename = service.data.get(ATTR_FILENAME)
|
||||
|
||||
overwrite = service.data.get(ATTR_OVERWRITE)
|
||||
|
||||
if subdir:
|
||||
# Check the path
|
||||
raise_if_invalid_path(subdir)
|
||||
|
||||
final_path = None
|
||||
|
||||
req = requests.get(url, stream=True, timeout=10)
|
||||
|
||||
if req.status_code != HTTPStatus.OK:
|
||||
|
@@ -12,6 +12,14 @@
|
||||
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
|
||||
}
|
||||
},
|
||||
"exceptions": {
|
||||
"subdir_invalid": {
|
||||
"message": "Invalid subdirectory, got: {subdir}"
|
||||
},
|
||||
"subdir_not_relative": {
|
||||
"message": "Subdirectory must be relative, got: {subdir}"
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
"download_file": {
|
||||
"name": "Download file",
|
||||
|
94
tests/components/downloader/conftest.py
Normal file
94
tests/components/downloader/conftest.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Provide common fixtures for downloader tests."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from requests_mock import Mocker
|
||||
|
||||
from homeassistant.components.downloader.const import (
|
||||
CONF_DOWNLOAD_DIR,
|
||||
DOMAIN,
|
||||
DOWNLOAD_COMPLETED_EVENT,
|
||||
DOWNLOAD_FAILED_EVENT,
|
||||
)
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_integration(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> MockConfigEntry:
|
||||
"""Set up the downloader integration for testing."""
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return mock_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(
|
||||
hass: HomeAssistant,
|
||||
download_dir: Path,
|
||||
) -> MockConfigEntry:
|
||||
"""Return a mocked config entry."""
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={CONF_DOWNLOAD_DIR: str(download_dir)},
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
return config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def download_dir(tmp_path: Path) -> Path:
|
||||
"""Return a download directory."""
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_download_request(
|
||||
requests_mock: Mocker,
|
||||
download_url: str,
|
||||
) -> None:
|
||||
"""Mock the download request."""
|
||||
requests_mock.get(download_url, text="{'one': 1}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def download_url() -> str:
|
||||
"""Return a mock download URL."""
|
||||
return "http://example.com/file.txt"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def download_completed(hass: HomeAssistant) -> asyncio.Event:
|
||||
"""Return an asyncio event to wait for download completion."""
|
||||
download_event = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def download_set(event: Event[dict[str, str]]) -> None:
|
||||
"""Set the event when download is completed."""
|
||||
download_event.set()
|
||||
|
||||
hass.bus.async_listen_once(f"{DOMAIN}_{DOWNLOAD_COMPLETED_EVENT}", download_set)
|
||||
|
||||
return download_event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def download_failed(hass: HomeAssistant) -> asyncio.Event:
|
||||
"""Return an asyncio event to wait for download failure."""
|
||||
download_event = asyncio.Event()
|
||||
|
||||
@callback
|
||||
def download_set(event: Event[dict[str, str]]) -> None:
|
||||
"""Set the event when download has failed."""
|
||||
download_event.set()
|
||||
|
||||
hass.bus.async_listen_once(f"{DOMAIN}_{DOWNLOAD_FAILED_EVENT}", download_set)
|
||||
|
||||
return download_event
|
@@ -1,6 +1,8 @@
|
||||
"""Tests for the downloader component init."""
|
||||
|
||||
from unittest.mock import patch
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.downloader.const import (
|
||||
CONF_DOWNLOAD_DIR,
|
||||
@@ -13,17 +15,57 @@ from homeassistant.core import HomeAssistant
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_initialization(hass: HomeAssistant) -> None:
|
||||
"""Test the initialization of the downloader component."""
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_DOWNLOAD_DIR: "/test_dir",
|
||||
},
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
with patch("os.path.isdir", return_value=True):
|
||||
assert await hass.config_entries.async_setup(config_entry.entry_id)
|
||||
@pytest.fixture
|
||||
def download_dir(tmp_path: Path, request: pytest.FixtureRequest) -> Path:
|
||||
"""Return a download directory."""
|
||||
if hasattr(request, "param"):
|
||||
return tmp_path / request.param
|
||||
return tmp_path
|
||||
|
||||
|
||||
async def test_config_entry_setup(
|
||||
hass: HomeAssistant, setup_integration: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test config entry setup."""
|
||||
config_entry = setup_integration
|
||||
|
||||
assert hass.services.has_service(DOMAIN, SERVICE_DOWNLOAD_FILE)
|
||||
assert config_entry.state is ConfigEntryState.LOADED
|
||||
|
||||
|
||||
async def test_config_entry_setup_relative_directory(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test config entry setup with a relative download directory."""
|
||||
relative_directory = "downloads"
|
||||
hass.config_entries.async_update_entry(
|
||||
mock_config_entry,
|
||||
data={**mock_config_entry.data, CONF_DOWNLOAD_DIR: relative_directory},
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
|
||||
# The config entry will fail to set up since the directory does not exist.
|
||||
# This is not relevant for this test.
|
||||
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
|
||||
assert mock_config_entry.data[CONF_DOWNLOAD_DIR] == hass.config.path(
|
||||
relative_directory
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"download_dir",
|
||||
[
|
||||
"not_existing_path",
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_config_entry_setup_not_existing_directory(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test config entry setup without existing download directory."""
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
|
||||
assert not hass.services.has_service(DOMAIN, SERVICE_DOWNLOAD_FILE)
|
||||
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
|
||||
|
54
tests/components/downloader/test_services.py
Normal file
54
tests/components/downloader/test_services.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Test downloader services."""
|
||||
|
||||
import asyncio
|
||||
from contextlib import AbstractContextManager, nullcontext as does_not_raise
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.downloader.const import DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ServiceValidationError
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_integration")
|
||||
@pytest.mark.parametrize(
|
||||
("subdir", "expected_result"),
|
||||
[
|
||||
("test", does_not_raise()),
|
||||
("test/path", does_not_raise()),
|
||||
("~test/path", pytest.raises(ServiceValidationError)),
|
||||
("~/../test/path", pytest.raises(ServiceValidationError)),
|
||||
("../test/path", pytest.raises(ServiceValidationError)),
|
||||
(".../test/path", pytest.raises(ServiceValidationError)),
|
||||
("/test/path", pytest.raises(ServiceValidationError)),
|
||||
],
|
||||
)
|
||||
async def test_download_invalid_subdir(
|
||||
hass: HomeAssistant,
|
||||
download_completed: asyncio.Event,
|
||||
download_failed: asyncio.Event,
|
||||
download_url: str,
|
||||
subdir: str,
|
||||
expected_result: AbstractContextManager,
|
||||
) -> None:
|
||||
"""Test service invalid subdirectory."""
|
||||
|
||||
async def call_service() -> None:
|
||||
"""Call the download service."""
|
||||
completed = hass.async_create_task(download_completed.wait())
|
||||
failed = hass.async_create_task(download_failed.wait())
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"download_file",
|
||||
{
|
||||
"url": download_url,
|
||||
"subdir": subdir,
|
||||
"filename": "file.txt",
|
||||
"overwrite": True,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
await asyncio.wait((completed, failed), return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
with expected_result:
|
||||
await call_service()
|
Reference in New Issue
Block a user