Add ai_task.generate_image action (#151101)

This commit is contained in:
Denis Shulyaka
2025-08-27 12:41:14 +03:00
committed by GitHub
parent adfdeff84c
commit 20e4d37cc6
12 changed files with 662 additions and 57 deletions

View File

@@ -3,8 +3,10 @@
import logging
from typing import Any
from aiohttp import web
import voluptuous as vol
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import (
@@ -26,14 +28,24 @@ from .const import (
ATTR_STRUCTURE,
ATTR_TASK_NAME,
DATA_COMPONENT,
DATA_IMAGES,
DATA_PREFERENCES,
DOMAIN,
SERVICE_GENERATE_DATA,
SERVICE_GENERATE_IMAGE,
AITaskEntityFeature,
)
from .entity import AITaskEntity
from .http import async_setup as async_setup_http
from .task import GenDataTask, GenDataTaskResult, async_generate_data
from .task import (
GenDataTask,
GenDataTaskResult,
GenImageTask,
GenImageTaskResult,
ImageData,
async_generate_data,
async_generate_image,
)
__all__ = [
"DOMAIN",
@@ -41,7 +53,11 @@ __all__ = [
"AITaskEntityFeature",
"GenDataTask",
"GenDataTaskResult",
"GenImageTask",
"GenImageTaskResult",
"ImageData",
"async_generate_data",
"async_generate_image",
"async_setup",
"async_setup_entry",
"async_unload_entry",
@@ -78,8 +94,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
hass.data[DATA_COMPONENT] = entity_component
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
hass.data[DATA_IMAGES] = {}
await hass.data[DATA_PREFERENCES].async_load()
async_setup_http(hass)
hass.http.register_view(ImageView)
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_DATA,
@@ -101,6 +119,23 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
supports_response=SupportsResponse.ONLY,
job_type=HassJobType.Coroutinefunction,
)
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_IMAGE,
async_service_generate_image,
schema=vol.Schema(
{
vol.Required(ATTR_TASK_NAME): cv.string,
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
vol.Required(ATTR_INSTRUCTIONS): cv.string,
vol.Optional(ATTR_ATTACHMENTS): vol.All(
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
),
}
),
supports_response=SupportsResponse.ONLY,
job_type=HassJobType.Coroutinefunction,
)
return True
@@ -115,11 +150,16 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_service_generate_data(call: ServiceCall) -> ServiceResponse:
"""Run the run task service."""
"""Run the data task service."""
result = await async_generate_data(hass=call.hass, **call.data)
return result.as_dict()
async def async_service_generate_image(call: ServiceCall) -> ServiceResponse:
"""Run the image task service."""
return await async_generate_image(hass=call.hass, **call.data)
class AITaskPreferences:
"""AI Task preferences."""
@@ -164,3 +204,29 @@ class AITaskPreferences:
def as_dict(self) -> dict[str, str | None]:
"""Get the current preferences."""
return {key: getattr(self, key) for key in self.KEYS}
class ImageView(HomeAssistantView):
"""View to generated images."""
url = f"/api/{DOMAIN}/images/{{filename}}"
name = f"api:{DOMAIN}/images"
requires_auth = False
async def get(
self,
request: web.Request,
filename: str,
) -> web.Response:
"""Serve image."""
hass = request.app[KEY_HASS]
image_storage = hass.data[DATA_IMAGES]
image_data = image_storage.get(filename)
if image_data is None:
raise web.HTTPNotFound
return web.Response(
body=image_data.data,
content_type=image_data.mime_type,
)

View File

@@ -12,12 +12,18 @@ if TYPE_CHECKING:
from . import AITaskPreferences
from .entity import AITaskEntity
from .task import ImageData
DOMAIN = "ai_task"
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images")
IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour
MAX_IMAGES = 20
SERVICE_GENERATE_DATA = "generate_data"
SERVICE_GENERATE_IMAGE = "generate_image"
ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name"
@@ -38,3 +44,6 @@ class AITaskEntityFeature(IntFlag):
SUPPORT_ATTACHMENTS = 2
"""Support attachments with generate data."""
GENERATE_IMAGE = 4
"""Generate images based on instructions."""

View File

@@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
from .task import GenDataTask, GenDataTaskResult
from .task import GenDataTask, GenDataTaskResult, GenImageTask, GenImageTaskResult
class AITaskEntity(RestoreEntity):
@@ -57,7 +57,7 @@ class AITaskEntity(RestoreEntity):
async def _async_get_ai_task_chat_log(
self,
session: ChatSession,
task: GenDataTask,
task: GenDataTask | GenImageTask,
) -> AsyncGenerator[ChatLog]:
"""Context manager used to manage the ChatLog used during an AI Task."""
# pylint: disable-next=contextmanager-generator-missing-cleanup
@@ -104,3 +104,23 @@ class AITaskEntity(RestoreEntity):
) -> GenDataTaskResult:
"""Handle a gen data task."""
raise NotImplementedError
@final
async def internal_async_generate_image(
self,
session: ChatSession,
task: GenImageTask,
) -> GenImageTaskResult:
"""Run a gen image task."""
self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state()
async with self._async_get_ai_task_chat_log(session, task) as chat_log:
return await self._async_generate_image(task, chat_log)
async def _async_generate_image(
self,
task: GenImageTask,
chat_log: ChatLog,
) -> GenImageTaskResult:
"""Handle a gen image task."""
raise NotImplementedError

View File

@@ -7,6 +7,9 @@
"services": {
"generate_data": {
"service": "mdi:file-star-four-points-outline"
},
"generate_image": {
"service": "mdi:star-four-points-box-outline"
}
}
}

View File

@@ -1,7 +1,7 @@
{
"domain": "ai_task",
"name": "AI Task",
"after_dependencies": ["camera"],
"after_dependencies": ["camera", "http"],
"codeowners": ["@home-assistant/core"],
"dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task",

View File

@@ -0,0 +1,81 @@
"""Expose images as media sources."""
from __future__ import annotations
import logging
from homeassistant.components.media_player import BrowseError, MediaClass
from homeassistant.components.media_source import (
BrowseMediaSource,
MediaSource,
MediaSourceItem,
PlayMedia,
Unresolvable,
)
from homeassistant.core import HomeAssistant
from .const import DATA_IMAGES, DOMAIN
_LOGGER = logging.getLogger(__name__)
async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource:
"""Set up image media source."""
_LOGGER.debug("Setting up image media source")
return ImageMediaSource(hass)
class ImageMediaSource(MediaSource):
"""Provide images as media sources."""
name: str = "AI Generated Images"
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize ImageMediaSource."""
super().__init__(DOMAIN)
self.hass = hass
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url."""
image_storage = self.hass.data[DATA_IMAGES]
image = image_storage.get(item.identifier)
if image is None:
raise Unresolvable(f"Could not resolve media item: {item.identifier}")
return PlayMedia(f"/api/{DOMAIN}/images/{item.identifier}", image.mime_type)
async def async_browse_media(
self,
item: MediaSourceItem,
) -> BrowseMediaSource:
"""Return media."""
if item.identifier:
raise BrowseError("Unknown item")
image_storage = self.hass.data[DATA_IMAGES]
children = [
BrowseMediaSource(
domain=DOMAIN,
identifier=filename,
media_class=MediaClass.IMAGE,
media_content_type=image.mime_type,
title=image.title or filename,
can_play=True,
can_expand=False,
)
for filename, image in image_storage.items()
]
return BrowseMediaSource(
domain=DOMAIN,
identifier=None,
media_class=MediaClass.APP,
media_content_type="",
title="AI Generated Images",
can_play=False,
can_expand=True,
children_media_class=MediaClass.IMAGE,
children=children,
)

View File

@@ -31,3 +31,30 @@ generate_data:
media:
accept:
- "*"
generate_image:
fields:
task_name:
example: "picture of a dog"
required: true
selector:
text:
instructions:
example: "Generate a high quality square image of a dog on transparent background"
required: true
selector:
text:
multiline: true
entity_id:
required: true
selector:
entity:
filter:
domain: ai_task
supported_features:
- ai_task.AITaskEntityFeature.GENERATE_IMAGE
attachments:
required: false
selector:
media:
accept:
- "*"

View File

@@ -25,6 +25,28 @@
"description": "List of files to attach for multi-modal AI analysis."
}
}
},
"generate_image": {
"name": "Generate image",
"description": "Uses AI to generate image.",
"fields": {
"task_name": {
"name": "Task name",
"description": "Name of the task."
},
"instructions": {
"name": "Instructions",
"description": "Instructions that explains the image to be generated."
},
"entity_id": {
"name": "Entity ID",
"description": "Entity ID to run the task on."
},
"attachments": {
"name": "Attachments",
"description": "List of files to attach for using as references."
}
}
}
}
}

View File

@@ -3,6 +3,8 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from functools import partial
import mimetypes
from pathlib import Path
import tempfile
@@ -11,11 +13,22 @@ from typing import Any
import voluptuous as vol
from homeassistant.components import camera, conversation, media_source
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import HomeAssistant, ServiceResponse, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.chat_session import async_get_chat_session
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url
from homeassistant.util import RE_SANITIZE_FILENAME, slugify
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
from .const import (
DATA_COMPONENT,
DATA_IMAGES,
DATA_PREFERENCES,
DOMAIN,
IMAGE_EXPIRY_TIME,
MAX_IMAGES,
AITaskEntityFeature,
)
def _save_camera_snapshot(image: camera.Image) -> Path:
@@ -29,43 +42,15 @@ def _save_camera_snapshot(image: camera.Image) -> Path:
return Path(temp_file.name)
async def async_generate_data(
async def _resolve_attachments(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str | None = None,
instructions: str,
structure: vol.Schema | None = None,
session: ChatSession,
attachments: list[dict] | None = None,
) -> GenDataTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
if entity_id is None:
raise HomeAssistantError("No entity_id provided and no preferred entity set")
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating data"
)
# Resolve attachments
) -> list[conversation.Attachment]:
"""Resolve attachments for a task."""
resolved_attachments: list[conversation.Attachment] = []
created_files: list[Path] = []
if (
attachments
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
):
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
for attachment in attachments or []:
media_content_id = attachment["media_content_id"]
@@ -104,20 +89,59 @@ async def async_generate_data(
)
)
if not created_files:
return resolved_attachments
def cleanup_files() -> None:
"""Cleanup temporary files."""
for file in created_files:
file.unlink(missing_ok=True)
@callback
def cleanup_files_callback() -> None:
"""Cleanup temporary files."""
hass.async_add_executor_job(cleanup_files)
session.async_on_cleanup(cleanup_files_callback)
return resolved_attachments
async def async_generate_data(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str | None = None,
instructions: str,
structure: vol.Schema | None = None,
attachments: list[dict] | None = None,
) -> GenDataTaskResult:
"""Run a data generation task in the AI Task integration."""
if entity_id is None:
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
if entity_id is None:
raise HomeAssistantError("No entity_id provided and no preferred entity set")
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating data"
)
if (
attachments
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
):
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
with async_get_chat_session(hass) as session:
if created_files:
def cleanup_files() -> None:
"""Cleanup temporary files."""
for file in created_files:
file.unlink(missing_ok=True)
@callback
def cleanup_files_callback() -> None:
"""Cleanup temporary files."""
hass.async_add_executor_job(cleanup_files)
session.async_on_cleanup(cleanup_files_callback)
resolved_attachments = await _resolve_attachments(hass, session, attachments)
return await entity.internal_async_generate_data(
session,
@@ -130,6 +154,97 @@ async def async_generate_data(
)
def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None:
"""Remove old images to keep the storage size under the limit."""
if num_to_remove <= 0:
return
if num_to_remove >= len(image_storage):
image_storage.clear()
return
sorted_images = sorted(
image_storage.items(),
key=lambda item: item[1].timestamp,
)
for filename, _ in sorted_images[:num_to_remove]:
image_storage.pop(filename, None)
async def async_generate_image(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str,
instructions: str,
attachments: list[dict] | None = None,
) -> ServiceResponse:
"""Run an image generation task in the AI Task integration."""
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_IMAGE not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating images"
)
if (
attachments
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
):
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support attachments"
)
with async_get_chat_session(hass) as session:
resolved_attachments = await _resolve_attachments(hass, session, attachments)
task_result = await entity.internal_async_generate_image(
session,
GenImageTask(
name=task_name,
instructions=instructions,
attachments=resolved_attachments or None,
),
)
service_result = task_result.as_dict()
image_data = service_result.pop("image_data")
if service_result.get("revised_prompt") is None:
service_result["revised_prompt"] = instructions
image_storage = hass.data[DATA_IMAGES]
if len(image_storage) + 1 > MAX_IMAGES:
_cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES)
current_time = datetime.now()
ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))
filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}"
image_storage[filename] = ImageData(
data=image_data,
timestamp=int(current_time.timestamp()),
mime_type=task_result.mime_type,
title=service_result["revised_prompt"],
)
def _purge_image(filename: str, now: datetime) -> None:
"""Remove image from storage."""
image_storage.pop(filename, None)
if IMAGE_EXPIRY_TIME > 0:
async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename))
service_result["url"] = get_url(hass) + f"/api/{DOMAIN}/images/{filename}"
service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}"
return service_result
@dataclass(slots=True)
class GenDataTask:
"""Gen data task to be processed."""
@@ -167,3 +282,80 @@ class GenDataTaskResult:
"conversation_id": self.conversation_id,
"data": self.data,
}
@dataclass(slots=True)
class GenImageTask:
"""Gen image task to be processed."""
name: str
"""Name of the task."""
instructions: str
"""Instructions on what needs to be done."""
attachments: list[conversation.Attachment] | None = None
"""List of attachments to go along the instructions."""
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenImageTask {self.name}: {id(self)}>"
@dataclass(slots=True)
class GenImageTaskResult:
"""Result of gen image task."""
image_data: bytes
"""Raw image data generated by the model."""
conversation_id: str
"""Unique identifier for the conversation."""
mime_type: str
"""MIME type of the generated image."""
width: int | None = None
"""Width of the generated image, if available."""
height: int | None = None
"""Height of the generated image, if available."""
model: str | None = None
"""Model used to generate the image, if available."""
revised_prompt: str | None = None
"""Revised prompt used to generate the image, if applicable."""
def as_dict(self) -> dict[str, Any]:
"""Return result as a dict."""
return {
"image_data": self.image_data,
"conversation_id": self.conversation_id,
"mime_type": self.mime_type,
"width": self.width,
"height": self.height,
"model": self.model,
"revised_prompt": self.revised_prompt,
}
@dataclass(slots=True)
class ImageData:
"""Image data for stored generated images."""
data: bytes
"""Raw image data."""
timestamp: int
"""Timestamp when the image was generated, as a Unix timestamp."""
mime_type: str
"""MIME type of the image."""
title: str
"""Title of the image, usually the prompt used to generate it."""
def __str__(self) -> str:
"""Return image data as a string."""
return f"<ImageData {self.title}: {id(self)}>"

View File

@@ -10,6 +10,8 @@ from homeassistant.components.ai_task import (
AITaskEntityFeature,
GenDataTask,
GenDataTaskResult,
GenImageTask,
GenImageTaskResult,
)
from homeassistant.components.conversation import AssistantContent, ChatLog
from homeassistant.config_entries import ConfigEntry, ConfigFlow
@@ -36,13 +38,16 @@ class MockAITaskEntity(AITaskEntity):
_attr_name = "Test Task Entity"
_attr_supported_features = (
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS
AITaskEntityFeature.GENERATE_DATA
| AITaskEntityFeature.SUPPORT_ATTACHMENTS
| AITaskEntityFeature.GENERATE_IMAGE
)
def __init__(self) -> None:
"""Initialize the mock entity."""
super().__init__()
self.mock_generate_data_tasks = []
self.mock_generate_image_tasks = []
async def _async_generate_data(
self, task: GenDataTask, chat_log: ChatLog
@@ -63,6 +68,24 @@ class MockAITaskEntity(AITaskEntity):
data=data,
)
async def _async_generate_image(
self, task: GenImageTask, chat_log: ChatLog
) -> GenImageTaskResult:
"""Mock handling of generate image task."""
self.mock_generate_image_tasks.append(task)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, "")
)
return GenImageTaskResult(
conversation_id=chat_log.conversation_id,
image_data=b"mock_image_data",
mime_type="image/png",
width=1536,
height=1024,
model="mock_model",
revised_prompt="mock_revised_prompt",
)
@pytest.fixture
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:

View File

@@ -0,0 +1,64 @@
"""Test ai_task media source."""
import pytest
from homeassistant.components import media_source
from homeassistant.components.ai_task import ImageData
from homeassistant.core import HomeAssistant
@pytest.fixture(name="image_id")
async def mock_image_generate(hass: HomeAssistant) -> str:
"""Mock image generation and return the image_id."""
image_storage = hass.data.setdefault("ai_task_images", {})
filename = "2025-06-15_150640_test_task.png"
image_storage[filename] = ImageData(
data=b"A",
timestamp=1750000000,
mime_type="image/png",
title="Mock Image",
)
return filename
async def test_browsing(
hass: HomeAssistant, init_components: None, image_id: str
) -> None:
"""Test browsing image media source."""
item = await media_source.async_browse_media(hass, "media-source://ai_task")
assert item is not None
assert item.title == "AI Generated Images"
assert len(item.children) == 1
assert item.children[0].media_content_type == "image/png"
assert item.children[0].identifier == image_id
assert item.children[0].title == "Mock Image"
with pytest.raises(
media_source.BrowseError,
match="Unknown item",
):
await media_source.async_browse_media(
hass, "media-source://ai_task/invalid_path"
)
async def test_resolving(
hass: HomeAssistant, init_components: None, image_id: str
) -> None:
"""Test resolving."""
item = await media_source.async_resolve_media(
hass, f"media-source://ai_task/{image_id}", None
)
assert item is not None
assert item.url == f"/api/ai_task/images/{image_id}"
assert item.mime_type == "image/png"
invalid_id = "aabbccddeeff"
with pytest.raises(
media_source.Unresolvable,
match=f"Could not resolve media item: {invalid_id}",
):
await media_source.async_resolve_media(
hass, f"media-source://ai_task/{invalid_id}", None
)

View File

@@ -1,6 +1,6 @@
"""Test tasks for the AI Task integration."""
from datetime import timedelta
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import patch
@@ -9,7 +9,12 @@ import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import media_source
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
from homeassistant.components.ai_task import (
AITaskEntityFeature,
ImageData,
async_generate_data,
async_generate_image,
)
from homeassistant.components.camera import Image
from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN
@@ -232,7 +237,9 @@ async def test_generate_data_mixed_attachments(
hass,
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1),
)
await hass.async_block_till_done()
await hass.async_block_till_done() # Need several iterations
await hass.async_block_till_done() # because one iteration of the loop
await hass.async_block_till_done() # simply schedules the cleanup
# Verify the temporary file cleaned up
assert not camera_attachment.path.exists()
@@ -242,3 +249,94 @@ async def test_generate_data_mixed_attachments(
assert media_attachment.media_content_id == "media-source://media_player/video.mp4"
assert media_attachment.mime_type == "video/mp4"
assert media_attachment.path == Path("/media/test.mp4")
async def test_generate_image(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test generating image service."""
with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
):
await async_generate_image(
hass,
task_name="Test Task",
entity_id="ai_task.unknown",
instructions="Test prompt",
)
state = hass.states.get(TEST_ENTITY_ID)
assert state is not None
assert state.state == STATE_UNKNOWN
result = await async_generate_image(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
assert "image_data" not in result
assert result["media_source_id"].startswith("media-source://ai_task/images/")
assert result["media_source_id"].endswith("_test_task.png")
assert result["url"].startswith("http://10.10.10.10:8123/api/ai_task/images/")
assert result["url"].endswith("_test_task.png")
assert result["mime_type"] == "image/png"
assert result["model"] == "mock_model"
assert result["revised_prompt"] == "mock_revised_prompt"
assert result["height"] == 1024
assert result["width"] == 1536
state = hass.states.get(TEST_ENTITY_ID)
assert state is not None
assert state.state != STATE_UNKNOWN
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
with pytest.raises(
HomeAssistantError,
match="AI Task entity ai_task.test_task_entity does not support generating images",
):
await async_generate_image(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
async def test_image_cleanup(
hass: HomeAssistant,
init_components: None,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test image cache cleanup."""
image_storage = hass.data.setdefault("ai_task_images", {})
image_storage.clear()
image_storage.update(
{
str(idx): ImageData(
data=b"mock_image_data",
timestamp=int(datetime.now().timestamp()),
mime_type="image/png",
title="Test Image",
)
for idx in range(20)
}
)
assert len(image_storage) == 20
result = await async_generate_image(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
assert result["url"].split("/")[-1] in image_storage
assert len(image_storage) == 20
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(hours=1, seconds=1))
await hass.async_block_till_done()
assert len(image_storage) == 19