Compare commits

...

3 Commits

Author SHA1 Message Date
Claude 2303dc53e2 Allow media source platforms to defer registration via None
The MediaSourceProtocol now accepts None from async_get_media_source to
defer registration, and async_register_media_source is exposed so a
source can be wired up later. AI Task uses this to stay hidden until
the first image upload creates the storage folder.
2026-04-29 16:16:44 +00:00
Claude c5d7689674 Lazy register AI Task media source after first image upload 2026-04-29 14:48:25 +00:00
Claude dde9eeaf6d Hide AI Task media source until an image has been generated 2026-04-29 14:38:04 +00:00
5 changed files with 74 additions and 12 deletions
@@ -11,8 +11,14 @@ from homeassistant.exceptions import HomeAssistantError
from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR
async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
"""Set up local media source."""
async def async_get_media_source(hass: HomeAssistant) -> MediaSource | None:
"""Set up local media source.
The source is only exposed once an image has been generated. The local
source object is always created so that image generation can use it to
upload, and ``async_generate_image`` registers the source with media_source
after the first upload via :func:`media_source.async_register_media_source`.
"""
media_dirs = list(hass.config.media_dirs.values())
if not media_dirs:
@@ -29,4 +35,8 @@ async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
{IMAGE_DIR: str(media_dir)},
f"/{DOMAIN}",
)
if not await hass.async_add_executor_job(media_dir.exists):
return None
return source
+6
View File
@@ -231,6 +231,12 @@ async def async_generate_image(
target_folder, image_file
)
# The folder is created on first upload, so register the source now that
# there is content. async_get_media_source defers registration until the
# folder exists, and the helper is a no-op if the source is already
# registered.
media_source.async_register_media_source(hass, source)
item = media_source.MediaSourceItem.from_uri(
hass, service_result["media_source_id"], None
)
@@ -5,7 +5,7 @@ from __future__ import annotations
from typing import Protocol
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.integration_platform import (
async_process_integration_platforms,
@@ -36,6 +36,7 @@ __all__ = [
"PlayMedia",
"Unresolvable",
"async_browse_media",
"async_register_media_source",
"async_resolve_media",
"generate_media_source_id",
"is_media_source_id",
@@ -48,8 +49,12 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
class MediaSourceProtocol(Protocol):
"""Define the format of media_source platforms."""
async def async_get_media_source(self, hass: HomeAssistant) -> MediaSource:
"""Set up media source."""
async def async_get_media_source(self, hass: HomeAssistant) -> MediaSource | None:
"""Set up media source.
Return ``None`` to defer registration; the integration can register
the source later via :func:`async_register_media_source`.
"""
def is_media_source_id(media_content_id: str) -> bool:
@@ -65,6 +70,22 @@ def generate_media_source_id(domain: str, identifier: str) -> str:
return uri
@callback
def async_register_media_source(hass: HomeAssistant, source: MediaSource) -> None:
"""Register a media source.
Use this to register a source after integration setup, e.g. once content
becomes available. Calling this with a source whose domain is already
registered is a no-op.
"""
sources = hass.data[MEDIA_SOURCE_DATA]
if source.domain in sources:
return
sources[source.domain] = source
if isinstance(source, local_source.LocalSource):
hass.http.register_view(local_source.LocalMediaView(hass, source))
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the media_source component."""
hass.data[MEDIA_SOURCE_DATA] = {}
@@ -88,6 +109,6 @@ async def _process_media_source_platform(
) -> None:
"""Process a media source platform."""
source = await platform.async_get_media_source(hass)
hass.data[MEDIA_SOURCE_DATA][domain] = source
if isinstance(source, local_source.LocalSource):
hass.http.register_view(local_source.LocalMediaView(hass, source))
if source is None:
return
async_register_media_source(hass, source)
+22 -3
View File
@@ -1,20 +1,29 @@
"""Test ai_task media source."""
from unittest.mock import patch
import pytest
from homeassistant.components import media_source
from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE
from homeassistant.components.ai_task.media_source import async_get_media_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
async def test_local_media_source(hass: HomeAssistant, init_components: None) -> None:
"""Test that the image media source is created."""
"""Test the image media source is only registered once an image is generated."""
# The image folder does not exist yet, so the media source should not be
# listed as a top-level media source.
item = await media_source.async_browse_media(hass, "media-source://")
assert not any(c.title == "AI generated images" for c in item.children)
assert any(c.title == "AI generated images" for c in item.children)
# async_get_media_source returns None to defer registration.
assert await async_get_media_source(hass) is None
source = await async_get_media_source(hass)
# The local source is still configured internally so image generation can
# use it to upload new images.
source = hass.data[DATA_MEDIA_SOURCE]
assert isinstance(source, media_source.local_source.LocalSource)
assert source.name == "AI generated images"
assert source.domain == "ai_task"
@@ -26,6 +35,16 @@ async def test_local_media_source(hass: HomeAssistant, init_components: None) ->
)
assert source.url_prefix == "/ai_task"
# Once an image has been generated and the folder exists, the source is
# returned.
with patch(
"homeassistant.components.ai_task.media_source.Path.exists",
return_value=True,
):
result = await async_get_media_source(hass)
assert result is hass.data[DATA_MEDIA_SOURCE]
assert isinstance(result, media_source.local_source.LocalSource)
hass.config.media_dirs = {}
with pytest.raises(
+7 -1
View File
@@ -13,9 +13,10 @@ from homeassistant.components.ai_task import (
async_generate_data,
async_generate_image,
)
from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE
from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE, DOMAIN
from homeassistant.components.camera import Image
from homeassistant.components.conversation import async_get_chat_log
from homeassistant.components.media_source.const import MEDIA_SOURCE_DATA
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@@ -400,6 +401,9 @@ async def test_generate_image(
assert state is not None
assert state.state == STATE_UNKNOWN
# Until an image is generated the media source is not registered.
assert DOMAIN not in hass.data[MEDIA_SOURCE_DATA]
with patch.object(
hass.data[DATA_MEDIA_SOURCE],
"async_upload_media",
@@ -412,6 +416,8 @@ async def test_generate_image(
instructions="Test prompt",
)
mock_upload_media.assert_called_once()
# The first upload registers the source so URLs and browse listings work.
assert hass.data[MEDIA_SOURCE_DATA][DOMAIN] is hass.data[DATA_MEDIA_SOURCE]
assert "image_data" not in result
assert (
result["media_source_id"]