Add ai_task.generate_image action (#151101)

This commit is contained in:
Denis Shulyaka
2025-08-27 12:41:14 +03:00
committed by GitHub
parent adfdeff84c
commit 20e4d37cc6
12 changed files with 662 additions and 57 deletions

View File

@@ -3,8 +3,10 @@
import logging import logging
from typing import Any from typing import Any
from aiohttp import web
import voluptuous as vol import voluptuous as vol
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import ( from homeassistant.core import (
@@ -26,14 +28,24 @@ from .const import (
ATTR_STRUCTURE, ATTR_STRUCTURE,
ATTR_TASK_NAME, ATTR_TASK_NAME,
DATA_COMPONENT, DATA_COMPONENT,
DATA_IMAGES,
DATA_PREFERENCES, DATA_PREFERENCES,
DOMAIN, DOMAIN,
SERVICE_GENERATE_DATA, SERVICE_GENERATE_DATA,
SERVICE_GENERATE_IMAGE,
AITaskEntityFeature, AITaskEntityFeature,
) )
from .entity import AITaskEntity from .entity import AITaskEntity
from .http import async_setup as async_setup_http from .http import async_setup as async_setup_http
from .task import GenDataTask, GenDataTaskResult, async_generate_data from .task import (
GenDataTask,
GenDataTaskResult,
GenImageTask,
GenImageTaskResult,
ImageData,
async_generate_data,
async_generate_image,
)
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
@@ -41,7 +53,11 @@ __all__ = [
"AITaskEntityFeature", "AITaskEntityFeature",
"GenDataTask", "GenDataTask",
"GenDataTaskResult", "GenDataTaskResult",
"GenImageTask",
"GenImageTaskResult",
"ImageData",
"async_generate_data", "async_generate_data",
"async_generate_image",
"async_setup", "async_setup",
"async_setup_entry", "async_setup_entry",
"async_unload_entry", "async_unload_entry",
@@ -78,8 +94,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass) entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
hass.data[DATA_COMPONENT] = entity_component hass.data[DATA_COMPONENT] = entity_component
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass) hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
hass.data[DATA_IMAGES] = {}
await hass.data[DATA_PREFERENCES].async_load() await hass.data[DATA_PREFERENCES].async_load()
async_setup_http(hass) async_setup_http(hass)
hass.http.register_view(ImageView)
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,
SERVICE_GENERATE_DATA, SERVICE_GENERATE_DATA,
@@ -101,6 +119,23 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,
job_type=HassJobType.Coroutinefunction, job_type=HassJobType.Coroutinefunction,
) )
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_IMAGE,
async_service_generate_image,
schema=vol.Schema(
{
vol.Required(ATTR_TASK_NAME): cv.string,
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
vol.Required(ATTR_INSTRUCTIONS): cv.string,
vol.Optional(ATTR_ATTACHMENTS): vol.All(
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
),
}
),
supports_response=SupportsResponse.ONLY,
job_type=HassJobType.Coroutinefunction,
)
return True return True
@@ -115,11 +150,16 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_service_generate_data(call: ServiceCall) -> ServiceResponse: async def async_service_generate_data(call: ServiceCall) -> ServiceResponse:
"""Run the run task service.""" """Run the data task service."""
result = await async_generate_data(hass=call.hass, **call.data) result = await async_generate_data(hass=call.hass, **call.data)
return result.as_dict() return result.as_dict()
async def async_service_generate_image(call: ServiceCall) -> ServiceResponse:
"""Run the image task service."""
return await async_generate_image(hass=call.hass, **call.data)
class AITaskPreferences: class AITaskPreferences:
"""AI Task preferences.""" """AI Task preferences."""
@@ -164,3 +204,29 @@ class AITaskPreferences:
def as_dict(self) -> dict[str, str | None]: def as_dict(self) -> dict[str, str | None]:
"""Get the current preferences.""" """Get the current preferences."""
return {key: getattr(self, key) for key in self.KEYS} return {key: getattr(self, key) for key in self.KEYS}
class ImageView(HomeAssistantView):
"""View to generated images."""
url = f"/api/{DOMAIN}/images/{{filename}}"
name = f"api:{DOMAIN}/images"
requires_auth = False
async def get(
self,
request: web.Request,
filename: str,
) -> web.Response:
"""Serve image."""
hass = request.app[KEY_HASS]
image_storage = hass.data[DATA_IMAGES]
image_data = image_storage.get(filename)
if image_data is None:
raise web.HTTPNotFound
return web.Response(
body=image_data.data,
content_type=image_data.mime_type,
)

View File

@@ -12,12 +12,18 @@ if TYPE_CHECKING:
from . import AITaskPreferences from . import AITaskPreferences
from .entity import AITaskEntity from .entity import AITaskEntity
from .task import ImageData
DOMAIN = "ai_task" DOMAIN = "ai_task"
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images")
IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour
MAX_IMAGES = 20
SERVICE_GENERATE_DATA = "generate_data" SERVICE_GENERATE_DATA = "generate_data"
SERVICE_GENERATE_IMAGE = "generate_image"
ATTR_INSTRUCTIONS: Final = "instructions" ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name" ATTR_TASK_NAME: Final = "task_name"
@@ -38,3 +44,6 @@ class AITaskEntityFeature(IntFlag):
SUPPORT_ATTACHMENTS = 2 SUPPORT_ATTACHMENTS = 2
"""Support attachments with generate data.""" """Support attachments with generate data."""
GENERATE_IMAGE = 4
"""Generate images based on instructions."""

View File

@@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
from .task import GenDataTask, GenDataTaskResult from .task import GenDataTask, GenDataTaskResult, GenImageTask, GenImageTaskResult
class AITaskEntity(RestoreEntity): class AITaskEntity(RestoreEntity):
@@ -57,7 +57,7 @@ class AITaskEntity(RestoreEntity):
async def _async_get_ai_task_chat_log( async def _async_get_ai_task_chat_log(
self, self,
session: ChatSession, session: ChatSession,
task: GenDataTask, task: GenDataTask | GenImageTask,
) -> AsyncGenerator[ChatLog]: ) -> AsyncGenerator[ChatLog]:
"""Context manager used to manage the ChatLog used during an AI Task.""" """Context manager used to manage the ChatLog used during an AI Task."""
# pylint: disable-next=contextmanager-generator-missing-cleanup # pylint: disable-next=contextmanager-generator-missing-cleanup
@@ -104,3 +104,23 @@ class AITaskEntity(RestoreEntity):
) -> GenDataTaskResult: ) -> GenDataTaskResult:
"""Handle a gen data task.""" """Handle a gen data task."""
raise NotImplementedError raise NotImplementedError
@final
async def internal_async_generate_image(
self,
session: ChatSession,
task: GenImageTask,
) -> GenImageTaskResult:
"""Run a gen image task."""
self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state()
async with self._async_get_ai_task_chat_log(session, task) as chat_log:
return await self._async_generate_image(task, chat_log)
async def _async_generate_image(
self,
task: GenImageTask,
chat_log: ChatLog,
) -> GenImageTaskResult:
"""Handle a gen image task."""
raise NotImplementedError

View File

@@ -7,6 +7,9 @@
"services": { "services": {
"generate_data": { "generate_data": {
"service": "mdi:file-star-four-points-outline" "service": "mdi:file-star-four-points-outline"
},
"generate_image": {
"service": "mdi:star-four-points-box-outline"
} }
} }
} }

View File

@@ -1,7 +1,7 @@
{ {
"domain": "ai_task", "domain": "ai_task",
"name": "AI Task", "name": "AI Task",
"after_dependencies": ["camera"], "after_dependencies": ["camera", "http"],
"codeowners": ["@home-assistant/core"], "codeowners": ["@home-assistant/core"],
"dependencies": ["conversation", "media_source"], "dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task", "documentation": "https://www.home-assistant.io/integrations/ai_task",

View File

@@ -0,0 +1,81 @@
"""Expose images as media sources."""
from __future__ import annotations
import logging
from homeassistant.components.media_player import BrowseError, MediaClass
from homeassistant.components.media_source import (
BrowseMediaSource,
MediaSource,
MediaSourceItem,
PlayMedia,
Unresolvable,
)
from homeassistant.core import HomeAssistant
from .const import DATA_IMAGES, DOMAIN
_LOGGER = logging.getLogger(__name__)
async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource:
"""Set up image media source."""
_LOGGER.debug("Setting up image media source")
return ImageMediaSource(hass)
class ImageMediaSource(MediaSource):
"""Provide images as media sources."""
name: str = "AI Generated Images"
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize ImageMediaSource."""
super().__init__(DOMAIN)
self.hass = hass
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url."""
image_storage = self.hass.data[DATA_IMAGES]
image = image_storage.get(item.identifier)
if image is None:
raise Unresolvable(f"Could not resolve media item: {item.identifier}")
return PlayMedia(f"/api/{DOMAIN}/images/{item.identifier}", image.mime_type)
async def async_browse_media(
self,
item: MediaSourceItem,
) -> BrowseMediaSource:
"""Return media."""
if item.identifier:
raise BrowseError("Unknown item")
image_storage = self.hass.data[DATA_IMAGES]
children = [
BrowseMediaSource(
domain=DOMAIN,
identifier=filename,
media_class=MediaClass.IMAGE,
media_content_type=image.mime_type,
title=image.title or filename,
can_play=True,
can_expand=False,
)
for filename, image in image_storage.items()
]
return BrowseMediaSource(
domain=DOMAIN,
identifier=None,
media_class=MediaClass.APP,
media_content_type="",
title="AI Generated Images",
can_play=False,
can_expand=True,
children_media_class=MediaClass.IMAGE,
children=children,
)

View File

@@ -31,3 +31,30 @@ generate_data:
media: media:
accept: accept:
- "*" - "*"
generate_image:
fields:
task_name:
example: "picture of a dog"
required: true
selector:
text:
instructions:
example: "Generate a high quality square image of a dog on transparent background"
required: true
selector:
text:
multiline: true
entity_id:
required: true
selector:
entity:
filter:
domain: ai_task
supported_features:
- ai_task.AITaskEntityFeature.GENERATE_IMAGE
attachments:
required: false
selector:
media:
accept:
- "*"

View File

@@ -25,6 +25,28 @@
"description": "List of files to attach for multi-modal AI analysis." "description": "List of files to attach for multi-modal AI analysis."
} }
} }
},
"generate_image": {
"name": "Generate image",
"description": "Uses AI to generate image.",
"fields": {
"task_name": {
"name": "Task name",
"description": "Name of the task."
},
"instructions": {
"name": "Instructions",
"description": "Instructions that explains the image to be generated."
},
"entity_id": {
"name": "Entity ID",
"description": "Entity ID to run the task on."
},
"attachments": {
"name": "Attachments",
"description": "List of files to attach for using as references."
}
}
} }
} }
} }

View File

@@ -3,6 +3,8 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime
from functools import partial
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
import tempfile import tempfile
@@ -11,11 +13,22 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from homeassistant.components import camera, conversation, media_source from homeassistant.components import camera, conversation, media_source
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, ServiceResponse, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.chat_session import async_get_chat_session from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url
from homeassistant.util import RE_SANITIZE_FILENAME, slugify
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature from .const import (
DATA_COMPONENT,
DATA_IMAGES,
DATA_PREFERENCES,
DOMAIN,
IMAGE_EXPIRY_TIME,
MAX_IMAGES,
AITaskEntityFeature,
)
def _save_camera_snapshot(image: camera.Image) -> Path: def _save_camera_snapshot(image: camera.Image) -> Path:
@@ -29,43 +42,15 @@ def _save_camera_snapshot(image: camera.Image) -> Path:
return Path(temp_file.name) return Path(temp_file.name)
async def async_generate_data( async def _resolve_attachments(
hass: HomeAssistant, hass: HomeAssistant,
*, session: ChatSession,
task_name: str,
entity_id: str | None = None,
instructions: str,
structure: vol.Schema | None = None,
attachments: list[dict] | None = None, attachments: list[dict] | None = None,
) -> GenDataTaskResult: ) -> list[conversation.Attachment]:
"""Run a task in the AI Task integration.""" """Resolve attachments for a task."""
if entity_id is None:
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
if entity_id is None:
raise HomeAssistantError("No entity_id provided and no preferred entity set")
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating data"
)
# Resolve attachments
resolved_attachments: list[conversation.Attachment] = [] resolved_attachments: list[conversation.Attachment] = []
created_files: list[Path] = [] created_files: list[Path] = []
if (
attachments
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
):
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
for attachment in attachments or []: for attachment in attachments or []:
media_content_id = attachment["media_content_id"] media_content_id = attachment["media_content_id"]
@@ -104,8 +89,8 @@ async def async_generate_data(
) )
) )
with async_get_chat_session(hass) as session: if not created_files:
if created_files: return resolved_attachments
def cleanup_files() -> None: def cleanup_files() -> None:
"""Cleanup temporary files.""" """Cleanup temporary files."""
@@ -119,6 +104,45 @@ async def async_generate_data(
session.async_on_cleanup(cleanup_files_callback) session.async_on_cleanup(cleanup_files_callback)
return resolved_attachments
async def async_generate_data(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str | None = None,
instructions: str,
structure: vol.Schema | None = None,
attachments: list[dict] | None = None,
) -> GenDataTaskResult:
"""Run a data generation task in the AI Task integration."""
if entity_id is None:
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
if entity_id is None:
raise HomeAssistantError("No entity_id provided and no preferred entity set")
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating data"
)
if (
attachments
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
):
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
with async_get_chat_session(hass) as session:
resolved_attachments = await _resolve_attachments(hass, session, attachments)
return await entity.internal_async_generate_data( return await entity.internal_async_generate_data(
session, session,
GenDataTask( GenDataTask(
@@ -130,6 +154,97 @@ async def async_generate_data(
) )
def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None:
"""Remove old images to keep the storage size under the limit."""
if num_to_remove <= 0:
return
if num_to_remove >= len(image_storage):
image_storage.clear()
return
sorted_images = sorted(
image_storage.items(),
key=lambda item: item[1].timestamp,
)
for filename, _ in sorted_images[:num_to_remove]:
image_storage.pop(filename, None)
async def async_generate_image(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str,
instructions: str,
attachments: list[dict] | None = None,
) -> ServiceResponse:
"""Run an image generation task in the AI Task integration."""
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_IMAGE not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating images"
)
if (
attachments
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
):
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
with async_get_chat_session(hass) as session:
resolved_attachments = await _resolve_attachments(hass, session, attachments)
task_result = await entity.internal_async_generate_image(
session,
GenImageTask(
name=task_name,
instructions=instructions,
attachments=resolved_attachments or None,
),
)
service_result = task_result.as_dict()
image_data = service_result.pop("image_data")
if service_result.get("revised_prompt") is None:
service_result["revised_prompt"] = instructions
image_storage = hass.data[DATA_IMAGES]
if len(image_storage) + 1 > MAX_IMAGES:
_cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES)
current_time = datetime.now()
ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))
filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}"
image_storage[filename] = ImageData(
data=image_data,
timestamp=int(current_time.timestamp()),
mime_type=task_result.mime_type,
title=service_result["revised_prompt"],
)
def _purge_image(filename: str, now: datetime) -> None:
"""Remove image from storage."""
image_storage.pop(filename, None)
if IMAGE_EXPIRY_TIME > 0:
async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename))
service_result["url"] = get_url(hass) + f"/api/{DOMAIN}/images/{filename}"
service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}"
return service_result
@dataclass(slots=True) @dataclass(slots=True)
class GenDataTask: class GenDataTask:
"""Gen data task to be processed.""" """Gen data task to be processed."""
@@ -167,3 +282,80 @@ class GenDataTaskResult:
"conversation_id": self.conversation_id, "conversation_id": self.conversation_id,
"data": self.data, "data": self.data,
} }
@dataclass(slots=True)
class GenImageTask:
"""Gen image task to be processed."""
name: str
"""Name of the task."""
instructions: str
"""Instructions on what needs to be done."""
attachments: list[conversation.Attachment] | None = None
"""List of attachments to go along the instructions."""
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenImageTask {self.name}: {id(self)}>"
@dataclass(slots=True)
class GenImageTaskResult:
"""Result of gen image task."""
image_data: bytes
"""Raw image data generated by the model."""
conversation_id: str
"""Unique identifier for the conversation."""
mime_type: str
"""MIME type of the generated image."""
width: int | None = None
"""Width of the generated image, if available."""
height: int | None = None
"""Height of the generated image, if available."""
model: str | None = None
"""Model used to generate the image, if available."""
revised_prompt: str | None = None
"""Revised prompt used to generate the image, if applicable."""
def as_dict(self) -> dict[str, Any]:
"""Return result as a dict."""
return {
"image_data": self.image_data,
"conversation_id": self.conversation_id,
"mime_type": self.mime_type,
"width": self.width,
"height": self.height,
"model": self.model,
"revised_prompt": self.revised_prompt,
}
@dataclass(slots=True)
class ImageData:
"""Image data for stored generated images."""
data: bytes
"""Raw image data."""
timestamp: int
"""Timestamp when the image was generated, as a Unix timestamp."""
mime_type: str
"""MIME type of the image."""
title: str
"""Title of the image, usually the prompt used to generate it."""
def __str__(self) -> str:
"""Return image data as a string."""
return f"<ImageData {self.title}: {id(self)}>"

View File

@@ -10,6 +10,8 @@ from homeassistant.components.ai_task import (
AITaskEntityFeature, AITaskEntityFeature,
GenDataTask, GenDataTask,
GenDataTaskResult, GenDataTaskResult,
GenImageTask,
GenImageTaskResult,
) )
from homeassistant.components.conversation import AssistantContent, ChatLog from homeassistant.components.conversation import AssistantContent, ChatLog
from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.config_entries import ConfigEntry, ConfigFlow
@@ -36,13 +38,16 @@ class MockAITaskEntity(AITaskEntity):
_attr_name = "Test Task Entity" _attr_name = "Test Task Entity"
_attr_supported_features = ( _attr_supported_features = (
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS AITaskEntityFeature.GENERATE_DATA
| AITaskEntityFeature.SUPPORT_ATTACHMENTS
| AITaskEntityFeature.GENERATE_IMAGE
) )
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the mock entity.""" """Initialize the mock entity."""
super().__init__() super().__init__()
self.mock_generate_data_tasks = [] self.mock_generate_data_tasks = []
self.mock_generate_image_tasks = []
async def _async_generate_data( async def _async_generate_data(
self, task: GenDataTask, chat_log: ChatLog self, task: GenDataTask, chat_log: ChatLog
@@ -63,6 +68,24 @@ class MockAITaskEntity(AITaskEntity):
data=data, data=data,
) )
async def _async_generate_image(
self, task: GenImageTask, chat_log: ChatLog
) -> GenImageTaskResult:
"""Mock handling of generate image task."""
self.mock_generate_image_tasks.append(task)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, "")
)
return GenImageTaskResult(
conversation_id=chat_log.conversation_id,
image_data=b"mock_image_data",
mime_type="image/png",
width=1536,
height=1024,
model="mock_model",
revised_prompt="mock_revised_prompt",
)
@pytest.fixture @pytest.fixture
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:

View File

@@ -0,0 +1,64 @@
"""Test ai_task media source."""
import pytest
from homeassistant.components import media_source
from homeassistant.components.ai_task import ImageData
from homeassistant.core import HomeAssistant
@pytest.fixture(name="image_id")
async def mock_image_generate(hass: HomeAssistant) -> str:
"""Mock image generation and return the image_id."""
image_storage = hass.data.setdefault("ai_task_images", {})
filename = "2025-06-15_150640_test_task.png"
image_storage[filename] = ImageData(
data=b"A",
timestamp=1750000000,
mime_type="image/png",
title="Mock Image",
)
return filename
async def test_browsing(
hass: HomeAssistant, init_components: None, image_id: str
) -> None:
"""Test browsing image media source."""
item = await media_source.async_browse_media(hass, "media-source://ai_task")
assert item is not None
assert item.title == "AI Generated Images"
assert len(item.children) == 1
assert item.children[0].media_content_type == "image/png"
assert item.children[0].identifier == image_id
assert item.children[0].title == "Mock Image"
with pytest.raises(
media_source.BrowseError,
match="Unknown item",
):
await media_source.async_browse_media(
hass, "media-source://ai_task/invalid_path"
)
async def test_resolving(
hass: HomeAssistant, init_components: None, image_id: str
) -> None:
"""Test resolving."""
item = await media_source.async_resolve_media(
hass, f"media-source://ai_task/{image_id}", None
)
assert item is not None
assert item.url == f"/api/ai_task/images/{image_id}"
assert item.mime_type == "image/png"
invalid_id = "aabbccddeeff"
with pytest.raises(
media_source.Unresolvable,
match=f"Could not resolve media item: {invalid_id}",
):
await media_source.async_resolve_media(
hass, f"media-source://ai_task/{invalid_id}", None
)

View File

@@ -1,6 +1,6 @@
"""Test tasks for the AI Task integration.""" """Test tasks for the AI Task integration."""
from datetime import timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
@@ -9,7 +9,12 @@ import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components import media_source from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data from homeassistant.components.ai_task import (
AITaskEntityFeature,
ImageData,
async_generate_data,
async_generate_image,
)
from homeassistant.components.camera import Image from homeassistant.components.camera import Image
from homeassistant.components.conversation import async_get_chat_log from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN from homeassistant.const import STATE_UNKNOWN
@@ -232,7 +237,9 @@ async def test_generate_data_mixed_attachments(
hass, hass,
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1), dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1),
) )
await hass.async_block_till_done() await hass.async_block_till_done() # Need several iterations
await hass.async_block_till_done() # because one iteration of the loop
await hass.async_block_till_done() # simply schedules the cleanup
# Verify the temporary file cleaned up # Verify the temporary file cleaned up
assert not camera_attachment.path.exists() assert not camera_attachment.path.exists()
@@ -242,3 +249,94 @@ async def test_generate_data_mixed_attachments(
assert media_attachment.media_content_id == "media-source://media_player/video.mp4" assert media_attachment.media_content_id == "media-source://media_player/video.mp4"
assert media_attachment.mime_type == "video/mp4" assert media_attachment.mime_type == "video/mp4"
assert media_attachment.path == Path("/media/test.mp4") assert media_attachment.path == Path("/media/test.mp4")
async def test_generate_image(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test generating image service."""
with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
):
await async_generate_image(
hass,
task_name="Test Task",
entity_id="ai_task.unknown",
instructions="Test prompt",
)
state = hass.states.get(TEST_ENTITY_ID)
assert state is not None
assert state.state == STATE_UNKNOWN
result = await async_generate_image(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
assert "image_data" not in result
assert result["media_source_id"].startswith("media-source://ai_task/images/")
assert result["media_source_id"].endswith("_test_task.png")
assert result["url"].startswith("http://10.10.10.10:8123/api/ai_task/images/")
assert result["url"].endswith("_test_task.png")
assert result["mime_type"] == "image/png"
assert result["model"] == "mock_model"
assert result["revised_prompt"] == "mock_revised_prompt"
assert result["height"] == 1024
assert result["width"] == 1536
state = hass.states.get(TEST_ENTITY_ID)
assert state is not None
assert state.state != STATE_UNKNOWN
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
with pytest.raises(
HomeAssistantError,
match="AI Task entity ai_task.test_task_entity does not support generating images",
):
await async_generate_image(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
async def test_image_cleanup(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test image cache cleanup."""
image_storage = hass.data.setdefault("ai_task_images", {})
image_storage.clear()
image_storage.update(
{
str(idx): ImageData(
data=b"mock_image_data",
timestamp=int(datetime.now().timestamp()),
mime_type="image/png",
title="Test Image",
)
for idx in range(20)
}
)
assert len(image_storage) == 20
result = await async_generate_image(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
assert result["url"].split("/")[-1] in image_storage
assert len(image_storage) == 20
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(hours=1, seconds=1))
await hass.async_block_till_done()
assert len(image_storage) == 19