Compare commits

...

4 Commits

Author SHA1 Message Date
Paulus Schoutsen c046c0cde6 Set up ai_task media source eagerly and decouple view registration
The media_source component now loads integration platforms lazily, so
ai_task's async_get_media_source is no longer called at startup. ai_task
relied on that side effect to populate DATA_MEDIA_SOURCE (read directly
during image generation) and to register the LocalMediaView that serves
generated images, causing a KeyError on image generation.

Set up ai_task's local source eagerly in its own async_setup, and make
async_get_media_source a singleton so the eager setup and lazy media
browser processing share one instance.

Move the built-in local source's LocalMediaView registration into
media_source's async_setup so view registration is no longer a side
effect of the (now lazy) platform processing path.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-21 12:46:15 -04:00
Paulus Schoutsen 028c6aa107 Mark media source method as internal 2026-06-21 00:50:00 -04:00
Paulus Schoutsen ce49db9076 Simplify local source handling 2026-06-21 00:38:43 -04:00
Paulus Schoutsen 5d5d65ea75 Lazily load media_source integration platforms
Media source platforms are only needed when browsing or resolving
media. Load them on demand via LazyIntegrationPlatforms instead of
importing and processing every integration's media source platform when
the media_source component sets up. The built-in local source stays
eagerly loaded. Per-domain resolve/browse imports only the requested
source; browsing the root still loads them all.

Since media source platform processing is a coroutine
(async_get_media_source), LazyIntegrationPlatforms now awaits the
process callback result when it returns an awaitable.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-20 12:48:48 -04:00
11 changed files with 123 additions and 47 deletions
@@ -5,6 +5,7 @@ from typing import Any
import voluptuous as vol
from homeassistant.components.media_source import local_source
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import (
@@ -34,6 +35,7 @@ from .const import (
)
from .entity import AITaskEntity
from .http import async_setup as async_setup_http
from .media_source import async_get_media_source
from .task import (
GenDataTask,
GenDataTaskResult,
@@ -88,6 +90,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
await hass.data[DATA_PREFERENCES].async_load()
async_setup_http(hass)
if hass.config.media_dirs:
source = await async_get_media_source(hass)
hass.http.register_view(local_source.LocalMediaView(hass, source))
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_DATA,
@@ -2,14 +2,16 @@
from pathlib import Path
from homeassistant.components.media_source import MediaSource, local_source
from homeassistant.components.media_source import local_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.singleton import singleton
from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR
async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
@singleton(DATA_MEDIA_SOURCE, async_=True)
async def async_get_media_source(hass: HomeAssistant) -> local_source.LocalSource:
"""Set up local media source."""
media_dirs = list(hass.config.media_dirs.values())
@@ -20,11 +22,10 @@ async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
media_dir = Path(media_dirs[0]) / DOMAIN / IMAGE_DIR
hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource(
return local_source.LocalSource(
hass,
DOMAIN,
"AI generated images",
{IMAGE_DIR: str(media_dir)},
f"/{DOMAIN}",
)
return source
@@ -5,17 +5,16 @@ from typing import Protocol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.integration_platform import (
async_process_integration_platforms,
)
from homeassistant.helpers.integration_platform import LazyIntegrationPlatforms
from homeassistant.helpers.typing import ConfigType
from . import http, local_source
from .const import (
DATA_LOCAL_SOURCE,
DATA_MEDIA_SOURCE_PLATFORMS,
DOMAIN,
MEDIA_CLASS_MAP,
MEDIA_MIME_TYPES,
MEDIA_SOURCE_DATA,
URI_SCHEME,
URI_SCHEME_REGEX,
)
@@ -72,17 +71,18 @@ def generate_media_source_id(domain: str, identifier: str) -> str:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the media_source component."""
hass.data[MEDIA_SOURCE_DATA] = {}
hass.data[DATA_MEDIA_SOURCE_PLATFORMS] = LazyIntegrationPlatforms[MediaSource](
hass, DOMAIN, _process_media_source_platform
)
http.async_setup(hass)
# Local sources support
await _process_media_source_platform(hass, DOMAIN, local_source)
source = await local_source.async_get_media_source(hass)
hass.data[DATA_LOCAL_SOURCE] = source
hass.http.register_view(local_source.LocalMediaView(hass, source))
hass.http.register_view(local_source.UploadMediaView)
websocket_api.async_register_command(hass, local_source.websocket_remove_media)
await async_process_integration_platforms(
hass, DOMAIN, _process_media_source_platform
)
return True
@@ -90,9 +90,6 @@ async def _process_media_source_platform(
hass: HomeAssistant,
domain: str,
platform: MediaSourceProtocol,
) -> None:
) -> MediaSource:
"""Process a media source platform."""
source = await platform.async_get_media_source(hass)
hass.data[MEDIA_SOURCE_DATA][domain] = source
if isinstance(source, local_source.LocalSource):
hass.http.register_view(local_source.LocalMediaView(hass, source))
return await platform.async_get_media_source(hass)
@@ -7,10 +7,15 @@ from homeassistant.components.media_player import MediaClass
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from homeassistant.helpers.integration_platform import LazyIntegrationPlatforms
from .models import MediaSource
DOMAIN = "media_source"
MEDIA_SOURCE_DATA: HassKey[dict[str, MediaSource]] = HassKey(DOMAIN)
DATA_LOCAL_SOURCE: HassKey[MediaSource] = HassKey("media_source_local_source")
DATA_MEDIA_SOURCE_PLATFORMS: HassKey[LazyIntegrationPlatforms[MediaSource]] = HassKey(
"media_source_platforms"
)
MEDIA_MIME_TYPES = ("audio", "video", "image")
MEDIA_CLASS_MAP = {
"audio": MediaClass.MUSIC,
+22 -11
View File
@@ -3,17 +3,23 @@
from collections.abc import Callable
from homeassistant.components.media_player import BrowseError, BrowseMedia
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import HomeAssistant
from homeassistant.helpers.frame import report_usage
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
from .const import DOMAIN, MEDIA_SOURCE_DATA
from .const import DOMAIN
from .error import UnknownMediaSource, Unresolvable
from .models import BrowseMediaSource, MediaSourceItem, PlayMedia, RootBrowseMediaSource
from .models import (
BrowseMediaSource,
MediaSourceItem,
PlayMedia,
RootBrowseMediaSource,
_async_get_media_source,
_async_get_media_sources,
)
@callback
def _get_media_item(
async def _get_media_item(
hass: HomeAssistant, media_content_id: str | None, target_media_player: str | None
) -> MediaSourceItem:
"""Return media item."""
@@ -21,10 +27,14 @@ def _get_media_item(
item = MediaSourceItem.from_uri(hass, media_content_id, target_media_player)
else:
# We default to our own domain if its only one registered
domain = None if len(hass.data[MEDIA_SOURCE_DATA]) > 1 else DOMAIN
sources = await _async_get_media_sources(hass)
domain = None if len(sources) > 1 else DOMAIN
return MediaSourceItem(hass, domain, "", target_media_player)
if item.domain is not None and item.domain not in hass.data[MEDIA_SOURCE_DATA]:
if (
item.domain is not None
and await _async_get_media_source(hass, item.domain) is None
):
raise UnknownMediaSource(
translation_domain=DOMAIN,
translation_key="unknown_media_source",
@@ -41,11 +51,12 @@ async def async_browse_media(
content_filter: Callable[[BrowseMedia], bool] | None = None,
) -> BrowseMediaSource | RootBrowseMediaSource:
"""Return media player browse media results."""
if DOMAIN not in hass.data:
if DOMAIN not in hass.config.top_level_components:
raise BrowseError("Media Source not loaded")
try:
item = await _get_media_item(hass, media_content_id, None).async_browse()
media_item = await _get_media_item(hass, media_content_id, None)
item = await media_item.async_browse()
except ValueError as err:
raise BrowseError(
translation_domain=DOMAIN,
@@ -73,7 +84,7 @@ async def async_resolve_media(
target_media_player: str | None | UndefinedType = UNDEFINED,
) -> PlayMedia:
"""Get info to play media."""
if DOMAIN not in hass.data:
if DOMAIN not in hass.config.top_level_components:
raise Unresolvable("Media Source not loaded")
if target_media_player is UNDEFINED:
@@ -84,7 +95,7 @@ async def async_resolve_media(
target_media_player = None
try:
item = _get_media_item(hass, media_content_id, target_media_player)
item = await _get_media_item(hass, media_content_id, target_media_player)
except ValueError as err:
raise Unresolvable(
translation_domain=DOMAIN,
@@ -18,7 +18,7 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES, MEDIA_SOURCE_DATA
from .const import DATA_LOCAL_SOURCE, DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
from .error import Unresolvable
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
@@ -369,7 +369,7 @@ class UploadMediaView(http.HomeAssistantView):
if target_folder.domain != DOMAIN:
raise web.HTTPBadRequest
source = cast(LocalSource, hass.data[MEDIA_SOURCE_DATA][target_folder.domain])
source = cast(LocalSource, hass.data[DATA_LOCAL_SOURCE])
try:
uploaded_media_source_id = await source.async_upload_media(
target_folder, data["file"]
@@ -414,7 +414,7 @@ async def websocket_remove_media(
)
return
source = cast(LocalSource, hass.data[MEDIA_SOURCE_DATA][item.domain])
source = cast(LocalSource, hass.data[DATA_LOCAL_SOURCE])
try:
await source.async_delete_media(item)
@@ -4,15 +4,37 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from homeassistant.components.media_player import BrowseMedia, MediaClass, MediaType
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import HomeAssistant
from homeassistant.helpers.translation import async_get_cached_translations
from .const import MEDIA_SOURCE_DATA, URI_SCHEME, URI_SCHEME_REGEX
from .const import (
DATA_LOCAL_SOURCE,
DATA_MEDIA_SOURCE_PLATFORMS,
DOMAIN,
URI_SCHEME,
URI_SCHEME_REGEX,
)
if TYPE_CHECKING:
from pathlib import Path
async def _async_get_media_sources(hass: HomeAssistant) -> dict[str, MediaSource]:
"""Return all media sources, loading integration platforms on demand."""
sources: dict[str, MediaSource] = {DOMAIN: hass.data[DATA_LOCAL_SOURCE]}
sources.update(await hass.data[DATA_MEDIA_SOURCE_PLATFORMS].async_get_platforms())
return sources
async def _async_get_media_source(
hass: HomeAssistant, domain: str
) -> MediaSource | None:
"""Return the media source for a domain, loading it on demand."""
if domain == DOMAIN:
return hass.data[DATA_LOCAL_SOURCE]
return await hass.data[DATA_MEDIA_SOURCE_PLATFORMS].async_get_platform(domain)
@dataclass(slots=True)
class PlayMedia:
"""Represents a playable media."""
@@ -81,6 +103,7 @@ class MediaSourceItem:
can_expand=True,
children_media_class=MediaClass.APP,
)
sources = await _async_get_media_sources(self.hass)
base.children = sorted(
(
BrowseMediaSource(
@@ -93,24 +116,28 @@ class MediaSourceItem:
can_play=False,
can_expand=True,
)
for source in self.hass.data[MEDIA_SOURCE_DATA].values()
for source in sources.values()
),
key=lambda item: item.title,
)
return base
return await self.async_media_source().async_browse_media(self)
source = await self._async_media_source()
return await source.async_browse_media(self)
async def async_resolve(self) -> PlayMedia:
"""Resolve to playable item."""
return await self.async_media_source().async_resolve_media(self)
source = await self._async_media_source()
return await source.async_resolve_media(self)
@callback
def async_media_source(self) -> MediaSource:
async def _async_media_source(self) -> MediaSource:
"""Return media source that owns this item."""
if TYPE_CHECKING:
assert self.domain is not None
return self.hass.data[MEDIA_SOURCE_DATA][self.domain]
# Existence is validated by _get_media_item before browse/resolve.
source = await _async_get_media_source(self.hass, self.domain)
assert source is not None
return source
@classmethod
def from_uri(
@@ -265,7 +265,7 @@ async def _async_process_integration_platforms(
# Any = platform.
type ProcessPlatform[_R] = Callable[[HomeAssistant, str, Any], _R]
type ProcessPlatform[_R] = Callable[[HomeAssistant, str, Any], _R | Awaitable[_R]]
class LazyIntegrationPlatforms[_R]:
@@ -276,6 +276,8 @@ class LazyIntegrationPlatforms[_R]:
this only imports and processes the platform for an integration the first
time it is requested, and only for integrations that are loaded.
The process callback may be a coroutine function; its result is awaited.
The platform is intentionally not registered for preloading, since for a
rarely used platform that would import it for every integration during
loading, defeating the point of loading it lazily.
@@ -357,7 +359,10 @@ class LazyIntegrationPlatforms[_R]:
result: _R | None = None
if platform is not None:
try:
result = self._process_platform(self._hass, domain, platform)
processed = self._process_platform(self._hass, domain, platform)
if isinstance(processed, Awaitable):
processed = await processed
result = processed
except Exception:
_LOGGER.exception(
"Error processing %s platform for %s",
@@ -26,6 +26,9 @@ async def test_local_media_source(hass: HomeAssistant, init_components: None) ->
)
assert source.url_prefix == "/ai_task"
async def test_media_source_no_media_dirs(hass: HomeAssistant) -> None:
"""Test an error is raised when no media directories are configured."""
hass.config.media_dirs = {}
with pytest.raises(
+4 -4
View File
@@ -1,6 +1,6 @@
"""Test media source helpers."""
from unittest.mock import Mock, patch
from unittest.mock import patch
import pytest
@@ -120,10 +120,10 @@ async def test_async_unresolve_media(hass: HomeAssistant) -> None:
)
async def test_browse_resolve_without_setup() -> None:
async def test_browse_resolve_without_setup(hass: HomeAssistant) -> None:
"""Test browse and resolve work without being setup."""
with pytest.raises(BrowseError):
await media_source.async_browse_media(Mock(data={}), None)
await media_source.async_browse_media(hass, None)
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(Mock(data={}), None, None)
await media_source.async_resolve_media(hass, None, None)
@@ -471,3 +471,25 @@ async def test_lazy_integration_platforms_concurrent(hass: HomeAssistant) -> Non
assert results == [loaded_platform, loaded_platform]
# The platform was imported and processed exactly once.
assert processed == ["loaded"]
async def test_lazy_integration_platforms_async_process(hass: HomeAssistant) -> None:
"""Test a coroutine process callback is awaited and its result cached."""
loaded_platform = Mock()
mock_platform(hass, "loaded.platform_to_check", loaded_platform)
hass.config.components.add("loaded")
processed: list[str] = []
async def _process_platform(hass: HomeAssistant, domain: str, platform: Any) -> Any:
processed.append(domain)
return platform
platforms = LazyIntegrationPlatforms(hass, "platform_to_check", _process_platform)
assert await platforms.async_get_platform("loaded") is loaded_platform
assert processed == ["loaded"]
# The awaited result is cached, so a subsequent request does not reprocess.
assert await platforms.async_get_platform("loaded") is loaded_platform
assert processed == ["loaded"]