mirror of
https://github.com/home-assistant/core.git
synced 2026-06-25 08:05:21 +02:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c046c0cde6 | |||
| 028c6aa107 | |||
| ce49db9076 | |||
| 5d5d65ea75 |
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.media_source import local_source
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
|
||||
from homeassistant.core import (
|
||||
@@ -34,6 +35,7 @@ from .const import (
|
||||
)
|
||||
from .entity import AITaskEntity
|
||||
from .http import async_setup as async_setup_http
|
||||
from .media_source import async_get_media_source
|
||||
from .task import (
|
||||
GenDataTask,
|
||||
GenDataTaskResult,
|
||||
@@ -88,6 +90,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
|
||||
await hass.data[DATA_PREFERENCES].async_load()
|
||||
async_setup_http(hass)
|
||||
if hass.config.media_dirs:
|
||||
source = await async_get_media_source(hass)
|
||||
hass.http.register_view(local_source.LocalMediaView(hass, source))
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_DATA,
|
||||
|
||||
@@ -2,14 +2,16 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from homeassistant.components.media_source import MediaSource, local_source
|
||||
from homeassistant.components.media_source import local_source
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.singleton import singleton
|
||||
|
||||
from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR
|
||||
|
||||
|
||||
async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
|
||||
@singleton(DATA_MEDIA_SOURCE, async_=True)
|
||||
async def async_get_media_source(hass: HomeAssistant) -> local_source.LocalSource:
|
||||
"""Set up local media source."""
|
||||
media_dirs = list(hass.config.media_dirs.values())
|
||||
|
||||
@@ -20,11 +22,10 @@ async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
|
||||
|
||||
media_dir = Path(media_dirs[0]) / DOMAIN / IMAGE_DIR
|
||||
|
||||
hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource(
|
||||
return local_source.LocalSource(
|
||||
hass,
|
||||
DOMAIN,
|
||||
"AI generated images",
|
||||
{IMAGE_DIR: str(media_dir)},
|
||||
f"/{DOMAIN}",
|
||||
)
|
||||
return source
|
||||
|
||||
@@ -5,17 +5,16 @@ from typing import Protocol
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.integration_platform import (
|
||||
async_process_integration_platforms,
|
||||
)
|
||||
from homeassistant.helpers.integration_platform import LazyIntegrationPlatforms
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from . import http, local_source
|
||||
from .const import (
|
||||
DATA_LOCAL_SOURCE,
|
||||
DATA_MEDIA_SOURCE_PLATFORMS,
|
||||
DOMAIN,
|
||||
MEDIA_CLASS_MAP,
|
||||
MEDIA_MIME_TYPES,
|
||||
MEDIA_SOURCE_DATA,
|
||||
URI_SCHEME,
|
||||
URI_SCHEME_REGEX,
|
||||
)
|
||||
@@ -72,17 +71,18 @@ def generate_media_source_id(domain: str, identifier: str) -> str:
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the media_source component."""
|
||||
hass.data[MEDIA_SOURCE_DATA] = {}
|
||||
hass.data[DATA_MEDIA_SOURCE_PLATFORMS] = LazyIntegrationPlatforms[MediaSource](
|
||||
hass, DOMAIN, _process_media_source_platform
|
||||
)
|
||||
http.async_setup(hass)
|
||||
|
||||
# Local sources support
|
||||
await _process_media_source_platform(hass, DOMAIN, local_source)
|
||||
source = await local_source.async_get_media_source(hass)
|
||||
hass.data[DATA_LOCAL_SOURCE] = source
|
||||
hass.http.register_view(local_source.LocalMediaView(hass, source))
|
||||
hass.http.register_view(local_source.UploadMediaView)
|
||||
websocket_api.async_register_command(hass, local_source.websocket_remove_media)
|
||||
|
||||
await async_process_integration_platforms(
|
||||
hass, DOMAIN, _process_media_source_platform
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@@ -90,9 +90,6 @@ async def _process_media_source_platform(
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
platform: MediaSourceProtocol,
|
||||
) -> None:
|
||||
) -> MediaSource:
|
||||
"""Process a media source platform."""
|
||||
source = await platform.async_get_media_source(hass)
|
||||
hass.data[MEDIA_SOURCE_DATA][domain] = source
|
||||
if isinstance(source, local_source.LocalSource):
|
||||
hass.http.register_view(local_source.LocalMediaView(hass, source))
|
||||
return await platform.async_get_media_source(hass)
|
||||
|
||||
@@ -7,10 +7,15 @@ from homeassistant.components.media_player import MediaClass
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.helpers.integration_platform import LazyIntegrationPlatforms
|
||||
|
||||
from .models import MediaSource
|
||||
|
||||
DOMAIN = "media_source"
|
||||
MEDIA_SOURCE_DATA: HassKey[dict[str, MediaSource]] = HassKey(DOMAIN)
|
||||
DATA_LOCAL_SOURCE: HassKey[MediaSource] = HassKey("media_source_local_source")
|
||||
DATA_MEDIA_SOURCE_PLATFORMS: HassKey[LazyIntegrationPlatforms[MediaSource]] = HassKey(
|
||||
"media_source_platforms"
|
||||
)
|
||||
MEDIA_MIME_TYPES = ("audio", "video", "image")
|
||||
MEDIA_CLASS_MAP = {
|
||||
"audio": MediaClass.MUSIC,
|
||||
|
||||
@@ -3,17 +3,23 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from homeassistant.components.media_player import BrowseError, BrowseMedia
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.frame import report_usage
|
||||
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
|
||||
|
||||
from .const import DOMAIN, MEDIA_SOURCE_DATA
|
||||
from .const import DOMAIN
|
||||
from .error import UnknownMediaSource, Unresolvable
|
||||
from .models import BrowseMediaSource, MediaSourceItem, PlayMedia, RootBrowseMediaSource
|
||||
from .models import (
|
||||
BrowseMediaSource,
|
||||
MediaSourceItem,
|
||||
PlayMedia,
|
||||
RootBrowseMediaSource,
|
||||
_async_get_media_source,
|
||||
_async_get_media_sources,
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def _get_media_item(
|
||||
async def _get_media_item(
|
||||
hass: HomeAssistant, media_content_id: str | None, target_media_player: str | None
|
||||
) -> MediaSourceItem:
|
||||
"""Return media item."""
|
||||
@@ -21,10 +27,14 @@ def _get_media_item(
|
||||
item = MediaSourceItem.from_uri(hass, media_content_id, target_media_player)
|
||||
else:
|
||||
# We default to our own domain if its only one registered
|
||||
domain = None if len(hass.data[MEDIA_SOURCE_DATA]) > 1 else DOMAIN
|
||||
sources = await _async_get_media_sources(hass)
|
||||
domain = None if len(sources) > 1 else DOMAIN
|
||||
return MediaSourceItem(hass, domain, "", target_media_player)
|
||||
|
||||
if item.domain is not None and item.domain not in hass.data[MEDIA_SOURCE_DATA]:
|
||||
if (
|
||||
item.domain is not None
|
||||
and await _async_get_media_source(hass, item.domain) is None
|
||||
):
|
||||
raise UnknownMediaSource(
|
||||
translation_domain=DOMAIN,
|
||||
translation_key="unknown_media_source",
|
||||
@@ -41,11 +51,12 @@ async def async_browse_media(
|
||||
content_filter: Callable[[BrowseMedia], bool] | None = None,
|
||||
) -> BrowseMediaSource | RootBrowseMediaSource:
|
||||
"""Return media player browse media results."""
|
||||
if DOMAIN not in hass.data:
|
||||
if DOMAIN not in hass.config.top_level_components:
|
||||
raise BrowseError("Media Source not loaded")
|
||||
|
||||
try:
|
||||
item = await _get_media_item(hass, media_content_id, None).async_browse()
|
||||
media_item = await _get_media_item(hass, media_content_id, None)
|
||||
item = await media_item.async_browse()
|
||||
except ValueError as err:
|
||||
raise BrowseError(
|
||||
translation_domain=DOMAIN,
|
||||
@@ -73,7 +84,7 @@ async def async_resolve_media(
|
||||
target_media_player: str | None | UndefinedType = UNDEFINED,
|
||||
) -> PlayMedia:
|
||||
"""Get info to play media."""
|
||||
if DOMAIN not in hass.data:
|
||||
if DOMAIN not in hass.config.top_level_components:
|
||||
raise Unresolvable("Media Source not loaded")
|
||||
|
||||
if target_media_player is UNDEFINED:
|
||||
@@ -84,7 +95,7 @@ async def async_resolve_media(
|
||||
target_media_player = None
|
||||
|
||||
try:
|
||||
item = _get_media_item(hass, media_content_id, target_media_player)
|
||||
item = await _get_media_item(hass, media_content_id, target_media_player)
|
||||
except ValueError as err:
|
||||
raise Unresolvable(
|
||||
translation_domain=DOMAIN,
|
||||
|
||||
@@ -18,7 +18,7 @@ from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
|
||||
|
||||
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES, MEDIA_SOURCE_DATA
|
||||
from .const import DATA_LOCAL_SOURCE, DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
|
||||
from .error import Unresolvable
|
||||
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
|
||||
|
||||
@@ -369,7 +369,7 @@ class UploadMediaView(http.HomeAssistantView):
|
||||
if target_folder.domain != DOMAIN:
|
||||
raise web.HTTPBadRequest
|
||||
|
||||
source = cast(LocalSource, hass.data[MEDIA_SOURCE_DATA][target_folder.domain])
|
||||
source = cast(LocalSource, hass.data[DATA_LOCAL_SOURCE])
|
||||
try:
|
||||
uploaded_media_source_id = await source.async_upload_media(
|
||||
target_folder, data["file"]
|
||||
@@ -414,7 +414,7 @@ async def websocket_remove_media(
|
||||
)
|
||||
return
|
||||
|
||||
source = cast(LocalSource, hass.data[MEDIA_SOURCE_DATA][item.domain])
|
||||
source = cast(LocalSource, hass.data[DATA_LOCAL_SOURCE])
|
||||
|
||||
try:
|
||||
await source.async_delete_media(item)
|
||||
|
||||
@@ -4,15 +4,37 @@ from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from homeassistant.components.media_player import BrowseMedia, MediaClass, MediaType
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.translation import async_get_cached_translations
|
||||
|
||||
from .const import MEDIA_SOURCE_DATA, URI_SCHEME, URI_SCHEME_REGEX
|
||||
from .const import (
|
||||
DATA_LOCAL_SOURCE,
|
||||
DATA_MEDIA_SOURCE_PLATFORMS,
|
||||
DOMAIN,
|
||||
URI_SCHEME,
|
||||
URI_SCHEME_REGEX,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
async def _async_get_media_sources(hass: HomeAssistant) -> dict[str, MediaSource]:
|
||||
"""Return all media sources, loading integration platforms on demand."""
|
||||
sources: dict[str, MediaSource] = {DOMAIN: hass.data[DATA_LOCAL_SOURCE]}
|
||||
sources.update(await hass.data[DATA_MEDIA_SOURCE_PLATFORMS].async_get_platforms())
|
||||
return sources
|
||||
|
||||
|
||||
async def _async_get_media_source(
|
||||
hass: HomeAssistant, domain: str
|
||||
) -> MediaSource | None:
|
||||
"""Return the media source for a domain, loading it on demand."""
|
||||
if domain == DOMAIN:
|
||||
return hass.data[DATA_LOCAL_SOURCE]
|
||||
return await hass.data[DATA_MEDIA_SOURCE_PLATFORMS].async_get_platform(domain)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PlayMedia:
|
||||
"""Represents a playable media."""
|
||||
@@ -81,6 +103,7 @@ class MediaSourceItem:
|
||||
can_expand=True,
|
||||
children_media_class=MediaClass.APP,
|
||||
)
|
||||
sources = await _async_get_media_sources(self.hass)
|
||||
base.children = sorted(
|
||||
(
|
||||
BrowseMediaSource(
|
||||
@@ -93,24 +116,28 @@ class MediaSourceItem:
|
||||
can_play=False,
|
||||
can_expand=True,
|
||||
)
|
||||
for source in self.hass.data[MEDIA_SOURCE_DATA].values()
|
||||
for source in sources.values()
|
||||
),
|
||||
key=lambda item: item.title,
|
||||
)
|
||||
return base
|
||||
|
||||
return await self.async_media_source().async_browse_media(self)
|
||||
source = await self._async_media_source()
|
||||
return await source.async_browse_media(self)
|
||||
|
||||
async def async_resolve(self) -> PlayMedia:
|
||||
"""Resolve to playable item."""
|
||||
return await self.async_media_source().async_resolve_media(self)
|
||||
source = await self._async_media_source()
|
||||
return await source.async_resolve_media(self)
|
||||
|
||||
@callback
|
||||
def async_media_source(self) -> MediaSource:
|
||||
async def _async_media_source(self) -> MediaSource:
|
||||
"""Return media source that owns this item."""
|
||||
if TYPE_CHECKING:
|
||||
assert self.domain is not None
|
||||
return self.hass.data[MEDIA_SOURCE_DATA][self.domain]
|
||||
# Existence is validated by _get_media_item before browse/resolve.
|
||||
source = await _async_get_media_source(self.hass, self.domain)
|
||||
assert source is not None
|
||||
return source
|
||||
|
||||
@classmethod
|
||||
def from_uri(
|
||||
|
||||
@@ -265,7 +265,7 @@ async def _async_process_integration_platforms(
|
||||
|
||||
|
||||
# Any = platform.
|
||||
type ProcessPlatform[_R] = Callable[[HomeAssistant, str, Any], _R]
|
||||
type ProcessPlatform[_R] = Callable[[HomeAssistant, str, Any], _R | Awaitable[_R]]
|
||||
|
||||
|
||||
class LazyIntegrationPlatforms[_R]:
|
||||
@@ -276,6 +276,8 @@ class LazyIntegrationPlatforms[_R]:
|
||||
this only imports and processes the platform for an integration the first
|
||||
time it is requested, and only for integrations that are loaded.
|
||||
|
||||
The process callback may be a coroutine function; its result is awaited.
|
||||
|
||||
The platform is intentionally not registered for preloading, since for a
|
||||
rarely used platform that would import it for every integration during
|
||||
loading, defeating the point of loading it lazily.
|
||||
@@ -357,7 +359,10 @@ class LazyIntegrationPlatforms[_R]:
|
||||
result: _R | None = None
|
||||
if platform is not None:
|
||||
try:
|
||||
result = self._process_platform(self._hass, domain, platform)
|
||||
processed = self._process_platform(self._hass, domain, platform)
|
||||
if isinstance(processed, Awaitable):
|
||||
processed = await processed
|
||||
result = processed
|
||||
except Exception:
|
||||
_LOGGER.exception(
|
||||
"Error processing %s platform for %s",
|
||||
|
||||
@@ -26,6 +26,9 @@ async def test_local_media_source(hass: HomeAssistant, init_components: None) ->
|
||||
)
|
||||
assert source.url_prefix == "/ai_task"
|
||||
|
||||
|
||||
async def test_media_source_no_media_dirs(hass: HomeAssistant) -> None:
|
||||
"""Test an error is raised when no media directories are configured."""
|
||||
hass.config.media_dirs = {}
|
||||
|
||||
with pytest.raises(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Test media source helpers."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -120,10 +120,10 @@ async def test_async_unresolve_media(hass: HomeAssistant) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def test_browse_resolve_without_setup() -> None:
|
||||
async def test_browse_resolve_without_setup(hass: HomeAssistant) -> None:
|
||||
"""Test browse and resolve work without being setup."""
|
||||
with pytest.raises(BrowseError):
|
||||
await media_source.async_browse_media(Mock(data={}), None)
|
||||
await media_source.async_browse_media(hass, None)
|
||||
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await media_source.async_resolve_media(Mock(data={}), None, None)
|
||||
await media_source.async_resolve_media(hass, None, None)
|
||||
|
||||
@@ -471,3 +471,25 @@ async def test_lazy_integration_platforms_concurrent(hass: HomeAssistant) -> Non
|
||||
assert results == [loaded_platform, loaded_platform]
|
||||
# The platform was imported and processed exactly once.
|
||||
assert processed == ["loaded"]
|
||||
|
||||
|
||||
async def test_lazy_integration_platforms_async_process(hass: HomeAssistant) -> None:
|
||||
"""Test a coroutine process callback is awaited and its result cached."""
|
||||
loaded_platform = Mock()
|
||||
mock_platform(hass, "loaded.platform_to_check", loaded_platform)
|
||||
hass.config.components.add("loaded")
|
||||
|
||||
processed: list[str] = []
|
||||
|
||||
async def _process_platform(hass: HomeAssistant, domain: str, platform: Any) -> Any:
|
||||
processed.append(domain)
|
||||
return platform
|
||||
|
||||
platforms = LazyIntegrationPlatforms(hass, "platform_to_check", _process_platform)
|
||||
|
||||
assert await platforms.async_get_platform("loaded") is loaded_platform
|
||||
assert processed == ["loaded"]
|
||||
|
||||
# The awaited result is cached, so a subsequent request does not reprocess.
|
||||
assert await platforms.async_get_platform("loaded") is loaded_platform
|
||||
assert processed == ["loaded"]
|
||||
|
||||
Reference in New Issue
Block a user