diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index a16e11c05d7..adae039ea5c 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -3,8 +3,10 @@ import logging from typing import Any +from aiohttp import web import voluptuous as vol +from homeassistant.components.http import KEY_HASS, HomeAssistantView from homeassistant.config_entries import ConfigEntry from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR from homeassistant.core import ( @@ -26,14 +28,24 @@ from .const import ( ATTR_STRUCTURE, ATTR_TASK_NAME, DATA_COMPONENT, + DATA_IMAGES, DATA_PREFERENCES, DOMAIN, SERVICE_GENERATE_DATA, + SERVICE_GENERATE_IMAGE, AITaskEntityFeature, ) from .entity import AITaskEntity from .http import async_setup as async_setup_http -from .task import GenDataTask, GenDataTaskResult, async_generate_data +from .task import ( + GenDataTask, + GenDataTaskResult, + GenImageTask, + GenImageTaskResult, + ImageData, + async_generate_data, + async_generate_image, +) __all__ = [ "DOMAIN", @@ -41,7 +53,11 @@ __all__ = [ "AITaskEntityFeature", "GenDataTask", "GenDataTaskResult", + "GenImageTask", + "GenImageTaskResult", + "ImageData", "async_generate_data", + "async_generate_image", "async_setup", "async_setup_entry", "async_unload_entry", @@ -78,8 +94,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass) hass.data[DATA_COMPONENT] = entity_component hass.data[DATA_PREFERENCES] = AITaskPreferences(hass) + hass.data[DATA_IMAGES] = {} await hass.data[DATA_PREFERENCES].async_load() async_setup_http(hass) + hass.http.register_view(ImageView) hass.services.async_register( DOMAIN, SERVICE_GENERATE_DATA, @@ -101,6 +119,23 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: supports_response=SupportsResponse.ONLY, job_type=HassJobType.Coroutinefunction, ) + hass.services.async_register( + DOMAIN, + SERVICE_GENERATE_IMAGE, + async_service_generate_image, + schema=vol.Schema( + { + vol.Required(ATTR_TASK_NAME): cv.string, + vol.Required(ATTR_ENTITY_ID): cv.entity_id, + vol.Required(ATTR_INSTRUCTIONS): cv.string, + vol.Optional(ATTR_ATTACHMENTS): vol.All( + cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})] + ), + } + ), + supports_response=SupportsResponse.ONLY, + job_type=HassJobType.Coroutinefunction, + ) return True @@ -115,11 +150,16 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_service_generate_data(call: ServiceCall) -> ServiceResponse: - """Run the run task service.""" + """Run the data task service.""" result = await async_generate_data(hass=call.hass, **call.data) return result.as_dict() +async def async_service_generate_image(call: ServiceCall) -> ServiceResponse: + """Run the image task service.""" + return await async_generate_image(hass=call.hass, **call.data) + + class AITaskPreferences: """AI Task preferences.""" @@ -164,3 +204,29 @@ class AITaskPreferences: def as_dict(self) -> dict[str, str | None]: """Get the current preferences.""" return {key: getattr(self, key) for key in self.KEYS} + + +class ImageView(HomeAssistantView): + """View to generated images.""" + + url = f"/api/{DOMAIN}/images/{{filename}}" + name = f"api:{DOMAIN}/images" + requires_auth = False + + async def get( + self, + request: web.Request, + filename: str, + ) -> web.Response: + """Serve image.""" + hass = request.app[KEY_HASS] + image_storage = hass.data[DATA_IMAGES] + image_data = image_storage.get(filename) + + if image_data is None: + raise web.HTTPNotFound + + return web.Response( + body=image_data.data, + content_type=image_data.mime_type, + ) diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py index 09948e9b673..b62f8002ecf 100644 --- a/homeassistant/components/ai_task/const.py +++ b/homeassistant/components/ai_task/const.py @@ -12,12 +12,18 @@ if TYPE_CHECKING: from . import AITaskPreferences from .entity import AITaskEntity + from .task import ImageData DOMAIN = "ai_task" DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") +DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images") + +IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour +MAX_IMAGES = 20 SERVICE_GENERATE_DATA = "generate_data" +SERVICE_GENERATE_IMAGE = "generate_image" ATTR_INSTRUCTIONS: Final = "instructions" ATTR_TASK_NAME: Final = "task_name" @@ -38,3 +44,6 @@ class AITaskEntityFeature(IntFlag): SUPPORT_ATTACHMENTS = 2 """Support attachments with generate data.""" + + GENERATE_IMAGE = 4 + """Generate images based on instructions.""" diff --git a/homeassistant/components/ai_task/entity.py b/homeassistant/components/ai_task/entity.py index 4c5cd186943..5b11fe95f28 100644 --- a/homeassistant/components/ai_task/entity.py +++ b/homeassistant/components/ai_task/entity.py @@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.util import dt as dt_util from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature -from .task import GenDataTask, GenDataTaskResult +from .task import GenDataTask, GenDataTaskResult, GenImageTask, GenImageTaskResult class AITaskEntity(RestoreEntity): @@ -57,7 +57,7 @@ class AITaskEntity(RestoreEntity): async def _async_get_ai_task_chat_log( self, session: ChatSession, - task: GenDataTask, + task: GenDataTask | GenImageTask, ) -> AsyncGenerator[ChatLog]: """Context manager used to manage the ChatLog used during an AI Task.""" # pylint: disable-next=contextmanager-generator-missing-cleanup @@ -104,3 +104,23 @@ class AITaskEntity(RestoreEntity): ) -> GenDataTaskResult: """Handle a gen data task.""" raise NotImplementedError + + @final + async def internal_async_generate_image( + self, + session: ChatSession, + task: GenImageTask, + ) -> GenImageTaskResult: + """Run a gen image task.""" + self.__last_activity = dt_util.utcnow().isoformat() + self.async_write_ha_state() + async with self._async_get_ai_task_chat_log(session, task) as chat_log: + return await self._async_generate_image(task, chat_log) + + async def _async_generate_image( + self, + task: GenImageTask, + chat_log: ChatLog, + ) -> GenImageTaskResult: + """Handle a gen image task.""" + raise NotImplementedError diff --git a/homeassistant/components/ai_task/icons.json b/homeassistant/components/ai_task/icons.json index 24233372312..2765402abf8 100644 --- a/homeassistant/components/ai_task/icons.json +++ b/homeassistant/components/ai_task/icons.json @@ -7,6 +7,9 @@ "services": { "generate_data": { "service": "mdi:file-star-four-points-outline" + }, + "generate_image": { + "service": "mdi:star-four-points-box-outline" } } } diff --git a/homeassistant/components/ai_task/manifest.json b/homeassistant/components/ai_task/manifest.json index d05faf18055..9e2eec4651d 100644 --- a/homeassistant/components/ai_task/manifest.json +++ b/homeassistant/components/ai_task/manifest.json @@ -1,7 +1,7 @@ { "domain": "ai_task", "name": "AI Task", - "after_dependencies": ["camera"], + "after_dependencies": ["camera", "http"], "codeowners": ["@home-assistant/core"], "dependencies": ["conversation", "media_source"], "documentation": "https://www.home-assistant.io/integrations/ai_task", diff --git a/homeassistant/components/ai_task/media_source.py b/homeassistant/components/ai_task/media_source.py new file mode 100644 index 00000000000..08d3a29e95f --- /dev/null +++ b/homeassistant/components/ai_task/media_source.py @@ -0,0 +1,81 @@ +"""Expose images as media sources.""" + +from __future__ import annotations + +import logging + +from homeassistant.components.media_player import BrowseError, MediaClass +from homeassistant.components.media_source import ( + BrowseMediaSource, + MediaSource, + MediaSourceItem, + PlayMedia, + Unresolvable, +) +from homeassistant.core import HomeAssistant + +from .const import DATA_IMAGES, DOMAIN + +_LOGGER = logging.getLogger(__name__) + + +async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource: + """Set up image media source.""" + _LOGGER.debug("Setting up image media source") + return ImageMediaSource(hass) + + +class ImageMediaSource(MediaSource): + """Provide images as media sources.""" + + name: str = "AI Generated Images" + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize ImageMediaSource.""" + super().__init__(DOMAIN) + self.hass = hass + + async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: + """Resolve media to a url.""" + image_storage = self.hass.data[DATA_IMAGES] + image = image_storage.get(item.identifier) + + if image is None: + raise Unresolvable(f"Could not resolve media item: {item.identifier}") + + return PlayMedia(f"/api/{DOMAIN}/images/{item.identifier}", image.mime_type) + + async def async_browse_media( + self, + item: MediaSourceItem, + ) -> BrowseMediaSource: + """Return media.""" + if item.identifier: + raise BrowseError("Unknown item") + + image_storage = self.hass.data[DATA_IMAGES] + + children = [ + BrowseMediaSource( + domain=DOMAIN, + identifier=filename, + media_class=MediaClass.IMAGE, + media_content_type=image.mime_type, + title=image.title or filename, + can_play=True, + can_expand=False, + ) + for filename, image in image_storage.items() + ] + + return BrowseMediaSource( + domain=DOMAIN, + identifier=None, + media_class=MediaClass.APP, + media_content_type="", + title="AI Generated Images", + can_play=False, + can_expand=True, + children_media_class=MediaClass.IMAGE, + children=children, + ) diff --git a/homeassistant/components/ai_task/services.yaml b/homeassistant/components/ai_task/services.yaml index feefa70a30b..17a3b499bfe 100644 --- a/homeassistant/components/ai_task/services.yaml +++ b/homeassistant/components/ai_task/services.yaml @@ -31,3 +31,30 @@ generate_data: media: accept: - "*" +generate_image: + fields: + task_name: + example: "picture of a dog" + required: true + selector: + text: + instructions: + example: "Generate a high quality square image of a dog on transparent background" + required: true + selector: + text: + multiline: true + entity_id: + required: true + selector: + entity: + filter: + domain: ai_task + supported_features: + - ai_task.AITaskEntityFeature.GENERATE_IMAGE + attachments: + required: false + selector: + media: + accept: + - "*" diff --git a/homeassistant/components/ai_task/strings.json b/homeassistant/components/ai_task/strings.json index 261381b7c31..3ec366afb0d 100644 --- a/homeassistant/components/ai_task/strings.json +++ b/homeassistant/components/ai_task/strings.json @@ -25,6 +25,28 @@ "description": "List of files to attach for multi-modal AI analysis." } } + }, + "generate_image": { + "name": "Generate image", + "description": "Uses AI to generate image.", + "fields": { + "task_name": { + "name": "Task name", + "description": "Name of the task." + }, + "instructions": { + "name": "Instructions", + "description": "Instructions that explains the image to be generated." + }, + "entity_id": { + "name": "Entity ID", + "description": "Entity ID to run the task on." + }, + "attachments": { + "name": "Attachments", + "description": "List of files to attach for using as references." + } + } } } } diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index 3cc43f8c07a..4efe38425a8 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -3,6 +3,8 @@ from __future__ import annotations from dataclasses import dataclass +from datetime import datetime +from functools import partial import mimetypes from pathlib import Path import tempfile @@ -11,11 +13,22 @@ from typing import Any import voluptuous as vol from homeassistant.components import camera, conversation, media_source -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant, ServiceResponse, callback from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers.chat_session import async_get_chat_session +from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session +from homeassistant.helpers.event import async_call_later +from homeassistant.helpers.network import get_url +from homeassistant.util import RE_SANITIZE_FILENAME, slugify -from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature +from .const import ( + DATA_COMPONENT, + DATA_IMAGES, + DATA_PREFERENCES, + DOMAIN, + IMAGE_EXPIRY_TIME, + MAX_IMAGES, + AITaskEntityFeature, +) def _save_camera_snapshot(image: camera.Image) -> Path: @@ -29,43 +42,15 @@ def _save_camera_snapshot(image: camera.Image) -> Path: return Path(temp_file.name) -async def async_generate_data( +async def _resolve_attachments( hass: HomeAssistant, - *, - task_name: str, - entity_id: str | None = None, - instructions: str, - structure: vol.Schema | None = None, + session: ChatSession, attachments: list[dict] | None = None, -) -> GenDataTaskResult: - """Run a task in the AI Task integration.""" - if entity_id is None: - entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id - - if entity_id is None: - raise HomeAssistantError("No entity_id provided and no preferred entity set") - - entity = hass.data[DATA_COMPONENT].get_entity(entity_id) - if entity is None: - raise HomeAssistantError(f"AI Task entity {entity_id} not found") - - if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features: - raise HomeAssistantError( - f"AI Task entity {entity_id} does not support generating data" - ) - - # Resolve attachments +) -> list[conversation.Attachment]: + """Resolve attachments for a task.""" resolved_attachments: list[conversation.Attachment] = [] created_files: list[Path] = [] - if ( - attachments - and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features - ): - raise HomeAssistantError( - f"AI Task entity {entity_id} does not support attachments" - ) - for attachment in attachments or []: media_content_id = attachment["media_content_id"] @@ -104,20 +89,59 @@ async def async_generate_data( ) ) + if not created_files: + return resolved_attachments + + def cleanup_files() -> None: + """Cleanup temporary files.""" + for file in created_files: + file.unlink(missing_ok=True) + + @callback + def cleanup_files_callback() -> None: + """Cleanup temporary files.""" + hass.async_add_executor_job(cleanup_files) + + session.async_on_cleanup(cleanup_files_callback) + + return resolved_attachments + + +async def async_generate_data( + hass: HomeAssistant, + *, + task_name: str, + entity_id: str | None = None, + instructions: str, + structure: vol.Schema | None = None, + attachments: list[dict] | None = None, +) -> GenDataTaskResult: + """Run a data generation task in the AI Task integration.""" + if entity_id is None: + entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id + + if entity_id is None: + raise HomeAssistantError("No entity_id provided and no preferred entity set") + + entity = hass.data[DATA_COMPONENT].get_entity(entity_id) + if entity is None: + raise HomeAssistantError(f"AI Task entity {entity_id} not found") + + if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features: + raise HomeAssistantError( + f"AI Task entity {entity_id} does not support generating data" + ) + + if ( + attachments + and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features + ): + raise HomeAssistantError( + f"AI Task entity {entity_id} does not support attachments" + ) + with async_get_chat_session(hass) as session: - if created_files: - - def cleanup_files() -> None: - """Cleanup temporary files.""" - for file in created_files: - file.unlink(missing_ok=True) - - @callback - def cleanup_files_callback() -> None: - """Cleanup temporary files.""" - hass.async_add_executor_job(cleanup_files) - - session.async_on_cleanup(cleanup_files_callback) + resolved_attachments = await _resolve_attachments(hass, session, attachments) return await entity.internal_async_generate_data( session, @@ -130,6 +154,97 @@ async def async_generate_data( ) +def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None: + """Remove old images to keep the storage size under the limit.""" + if num_to_remove <= 0: + return + + if num_to_remove >= len(image_storage): + image_storage.clear() + return + + sorted_images = sorted( + image_storage.items(), + key=lambda item: item[1].timestamp, + ) + + for filename, _ in sorted_images[:num_to_remove]: + image_storage.pop(filename, None) + + +async def async_generate_image( + hass: HomeAssistant, + *, + task_name: str, + entity_id: str, + instructions: str, + attachments: list[dict] | None = None, +) -> ServiceResponse: + """Run an image generation task in the AI Task integration.""" + entity = hass.data[DATA_COMPONENT].get_entity(entity_id) + if entity is None: + raise HomeAssistantError(f"AI Task entity {entity_id} not found") + + if AITaskEntityFeature.GENERATE_IMAGE not in entity.supported_features: + raise HomeAssistantError( + f"AI Task entity {entity_id} does not support generating images" + ) + + if ( + attachments + and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features + ): + raise HomeAssistantError( + f"AI Task entity {entity_id} does not support attachments" + ) + + with async_get_chat_session(hass) as session: + resolved_attachments = await _resolve_attachments(hass, session, attachments) + + task_result = await entity.internal_async_generate_image( + session, + GenImageTask( + name=task_name, + instructions=instructions, + attachments=resolved_attachments or None, + ), + ) + + service_result = task_result.as_dict() + image_data = service_result.pop("image_data") + if service_result.get("revised_prompt") is None: + service_result["revised_prompt"] = instructions + + image_storage = hass.data[DATA_IMAGES] + + if len(image_storage) + 1 > MAX_IMAGES: + _cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES) + + current_time = datetime.now() + ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png" + sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name)) + filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}" + + image_storage[filename] = ImageData( + data=image_data, + timestamp=int(current_time.timestamp()), + mime_type=task_result.mime_type, + title=service_result["revised_prompt"], + ) + + def _purge_image(filename: str, now: datetime) -> None: + """Remove image from storage.""" + image_storage.pop(filename, None) + + if IMAGE_EXPIRY_TIME > 0: + async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename)) + + service_result["url"] = get_url(hass) + f"/api/{DOMAIN}/images/{filename}" + service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}" + + return service_result + + @dataclass(slots=True) class GenDataTask: """Gen data task to be processed.""" @@ -167,3 +282,80 @@ class GenDataTaskResult: "conversation_id": self.conversation_id, "data": self.data, } + + +@dataclass(slots=True) +class GenImageTask: + """Gen image task to be processed.""" + + name: str + """Name of the task.""" + + instructions: str + """Instructions on what needs to be done.""" + + attachments: list[conversation.Attachment] | None = None + """List of attachments to go along the instructions.""" + + def __str__(self) -> str: + """Return task as a string.""" + return f"" + + +@dataclass(slots=True) +class GenImageTaskResult: + """Result of gen image task.""" + + image_data: bytes + """Raw image data generated by the model.""" + + conversation_id: str + """Unique identifier for the conversation.""" + + mime_type: str + """MIME type of the generated image.""" + + width: int | None = None + """Width of the generated image, if available.""" + + height: int | None = None + """Height of the generated image, if available.""" + + model: str | None = None + """Model used to generate the image, if available.""" + + revised_prompt: str | None = None + """Revised prompt used to generate the image, if applicable.""" + + def as_dict(self) -> dict[str, Any]: + """Return result as a dict.""" + return { + "image_data": self.image_data, + "conversation_id": self.conversation_id, + "mime_type": self.mime_type, + "width": self.width, + "height": self.height, + "model": self.model, + "revised_prompt": self.revised_prompt, + } + + +@dataclass(slots=True) +class ImageData: + """Image data for stored generated images.""" + + data: bytes + """Raw image data.""" + + timestamp: int + """Timestamp when the image was generated, as a Unix timestamp.""" + + mime_type: str + """MIME type of the image.""" + + title: str + """Title of the image, usually the prompt used to generate it.""" + + def __str__(self) -> str: + """Return image data as a string.""" + return f"" diff --git a/tests/components/ai_task/conftest.py b/tests/components/ai_task/conftest.py index 05d34b15ddc..06f9a56a813 100644 --- a/tests/components/ai_task/conftest.py +++ b/tests/components/ai_task/conftest.py @@ -10,6 +10,8 @@ from homeassistant.components.ai_task import ( AITaskEntityFeature, GenDataTask, GenDataTaskResult, + GenImageTask, + GenImageTaskResult, ) from homeassistant.components.conversation import AssistantContent, ChatLog from homeassistant.config_entries import ConfigEntry, ConfigFlow @@ -36,13 +38,16 @@ class MockAITaskEntity(AITaskEntity): _attr_name = "Test Task Entity" _attr_supported_features = ( - AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS + AITaskEntityFeature.GENERATE_DATA + | AITaskEntityFeature.SUPPORT_ATTACHMENTS + | AITaskEntityFeature.GENERATE_IMAGE ) def __init__(self) -> None: """Initialize the mock entity.""" super().__init__() self.mock_generate_data_tasks = [] + self.mock_generate_image_tasks = [] async def _async_generate_data( self, task: GenDataTask, chat_log: ChatLog @@ -63,6 +68,24 @@ class MockAITaskEntity(AITaskEntity): data=data, ) + async def _async_generate_image( + self, task: GenImageTask, chat_log: ChatLog + ) -> GenImageTaskResult: + """Mock handling of generate image task.""" + self.mock_generate_image_tasks.append(task) + chat_log.async_add_assistant_content_without_tools( + AssistantContent(self.entity_id, "") + ) + return GenImageTaskResult( + conversation_id=chat_log.conversation_id, + image_data=b"mock_image_data", + mime_type="image/png", + width=1536, + height=1024, + model="mock_model", + revised_prompt="mock_revised_prompt", + ) + @pytest.fixture def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: diff --git a/tests/components/ai_task/test_media_source.py b/tests/components/ai_task/test_media_source.py new file mode 100644 index 00000000000..718d7299207 --- /dev/null +++ b/tests/components/ai_task/test_media_source.py @@ -0,0 +1,64 @@ +"""Test ai_task media source.""" + +import pytest + +from homeassistant.components import media_source +from homeassistant.components.ai_task import ImageData +from homeassistant.core import HomeAssistant + + +@pytest.fixture(name="image_id") +async def mock_image_generate(hass: HomeAssistant) -> str: + """Mock image generation and return the image_id.""" + image_storage = hass.data.setdefault("ai_task_images", {}) + filename = "2025-06-15_150640_test_task.png" + image_storage[filename] = ImageData( + data=b"A", + timestamp=1750000000, + mime_type="image/png", + title="Mock Image", + ) + return filename + + +async def test_browsing( + hass: HomeAssistant, init_components: None, image_id: str +) -> None: + """Test browsing image media source.""" + item = await media_source.async_browse_media(hass, "media-source://ai_task") + + assert item is not None + assert item.title == "AI Generated Images" + assert len(item.children) == 1 + assert item.children[0].media_content_type == "image/png" + assert item.children[0].identifier == image_id + assert item.children[0].title == "Mock Image" + + with pytest.raises( + media_source.BrowseError, + match="Unknown item", + ): + await media_source.async_browse_media( + hass, "media-source://ai_task/invalid_path" + ) + + +async def test_resolving( + hass: HomeAssistant, init_components: None, image_id: str +) -> None: + """Test resolving.""" + item = await media_source.async_resolve_media( + hass, f"media-source://ai_task/{image_id}", None + ) + assert item is not None + assert item.url == f"/api/ai_task/images/{image_id}" + assert item.mime_type == "image/png" + + invalid_id = "aabbccddeeff" + with pytest.raises( + media_source.Unresolvable, + match=f"Could not resolve media item: {invalid_id}", + ): + await media_source.async_resolve_media( + hass, f"media-source://ai_task/{invalid_id}", None + ) diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py index 7eb75b62bb0..2bebf7b60bb 100644 --- a/tests/components/ai_task/test_task.py +++ b/tests/components/ai_task/test_task.py @@ -1,6 +1,6 @@ """Test tasks for the AI Task integration.""" -from datetime import timedelta +from datetime import datetime, timedelta from pathlib import Path from unittest.mock import patch @@ -9,7 +9,12 @@ import pytest from syrupy.assertion import SnapshotAssertion from homeassistant.components import media_source -from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data +from homeassistant.components.ai_task import ( + AITaskEntityFeature, + ImageData, + async_generate_data, + async_generate_image, +) from homeassistant.components.camera import Image from homeassistant.components.conversation import async_get_chat_log from homeassistant.const import STATE_UNKNOWN @@ -232,7 +237,9 @@ async def test_generate_data_mixed_attachments( hass, dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1), ) - await hass.async_block_till_done() + await hass.async_block_till_done() # Need several iterations + await hass.async_block_till_done() # because one iteration of the loop + await hass.async_block_till_done() # simply schedules the cleanup # Verify the temporary file cleaned up assert not camera_attachment.path.exists() @@ -242,3 +249,94 @@ async def test_generate_data_mixed_attachments( assert media_attachment.media_content_id == "media-source://media_player/video.mp4" assert media_attachment.mime_type == "video/mp4" assert media_attachment.path == Path("/media/test.mp4") + + +async def test_generate_image( + hass: HomeAssistant, + init_components: None, + mock_ai_task_entity: MockAITaskEntity, +) -> None: + """Test generating image service.""" + with pytest.raises( + HomeAssistantError, match="AI Task entity ai_task.unknown not found" + ): + await async_generate_image( + hass, + task_name="Test Task", + entity_id="ai_task.unknown", + instructions="Test prompt", + ) + + state = hass.states.get(TEST_ENTITY_ID) + assert state is not None + assert state.state == STATE_UNKNOWN + + result = await async_generate_image( + hass, + task_name="Test Task", + entity_id=TEST_ENTITY_ID, + instructions="Test prompt", + ) + assert "image_data" not in result + assert result["media_source_id"].startswith("media-source://ai_task/images/") + assert result["media_source_id"].endswith("_test_task.png") + assert result["url"].startswith("http://10.10.10.10:8123/api/ai_task/images/") + assert result["url"].endswith("_test_task.png") + assert result["mime_type"] == "image/png" + assert result["model"] == "mock_model" + assert result["revised_prompt"] == "mock_revised_prompt" + assert result["height"] == 1024 + assert result["width"] == 1536 + + state = hass.states.get(TEST_ENTITY_ID) + assert state is not None + assert state.state != STATE_UNKNOWN + + mock_ai_task_entity.supported_features = AITaskEntityFeature(0) + with pytest.raises( + HomeAssistantError, + match="AI Task entity ai_task.test_task_entity does not support generating images", + ): + await async_generate_image( + hass, + task_name="Test Task", + entity_id=TEST_ENTITY_ID, + instructions="Test prompt", + ) + + +async def test_image_cleanup( + hass: HomeAssistant, + init_components: None, + mock_ai_task_entity: MockAITaskEntity, +) -> None: + """Test image cache cleanup.""" + image_storage = hass.data.setdefault("ai_task_images", {}) + image_storage.clear() + image_storage.update( + { + str(idx): ImageData( + data=b"mock_image_data", + timestamp=int(datetime.now().timestamp()), + mime_type="image/png", + title="Test Image", + ) + for idx in range(20) + } + ) + assert len(image_storage) == 20 + + result = await async_generate_image( + hass, + task_name="Test Task", + entity_id=TEST_ENTITY_ID, + instructions="Test prompt", + ) + + assert result["url"].split("/")[-1] in image_storage + assert len(image_storage) == 20 + + async_fire_time_changed(hass, dt_util.utcnow() + timedelta(hours=1, seconds=1)) + await hass.async_block_till_done() + + assert len(image_storage) == 19