mirror of
https://github.com/home-assistant/core.git
synced 2025-09-10 15:21:38 +02:00
Add ai_task.generate_image action (#151101)
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
|
||||
from homeassistant.core import (
|
||||
@@ -26,14 +28,24 @@ from .const import (
|
||||
ATTR_STRUCTURE,
|
||||
ATTR_TASK_NAME,
|
||||
DATA_COMPONENT,
|
||||
DATA_IMAGES,
|
||||
DATA_PREFERENCES,
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_DATA,
|
||||
SERVICE_GENERATE_IMAGE,
|
||||
AITaskEntityFeature,
|
||||
)
|
||||
from .entity import AITaskEntity
|
||||
from .http import async_setup as async_setup_http
|
||||
from .task import GenDataTask, GenDataTaskResult, async_generate_data
|
||||
from .task import (
|
||||
GenDataTask,
|
||||
GenDataTaskResult,
|
||||
GenImageTask,
|
||||
GenImageTaskResult,
|
||||
ImageData,
|
||||
async_generate_data,
|
||||
async_generate_image,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
@@ -41,7 +53,11 @@ __all__ = [
|
||||
"AITaskEntityFeature",
|
||||
"GenDataTask",
|
||||
"GenDataTaskResult",
|
||||
"GenImageTask",
|
||||
"GenImageTaskResult",
|
||||
"ImageData",
|
||||
"async_generate_data",
|
||||
"async_generate_image",
|
||||
"async_setup",
|
||||
"async_setup_entry",
|
||||
"async_unload_entry",
|
||||
@@ -78,8 +94,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
|
||||
hass.data[DATA_COMPONENT] = entity_component
|
||||
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
|
||||
hass.data[DATA_IMAGES] = {}
|
||||
await hass.data[DATA_PREFERENCES].async_load()
|
||||
async_setup_http(hass)
|
||||
hass.http.register_view(ImageView)
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_DATA,
|
||||
@@ -101,6 +119,23 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
job_type=HassJobType.Coroutinefunction,
|
||||
)
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_IMAGE,
|
||||
async_service_generate_image,
|
||||
schema=vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_TASK_NAME): cv.string,
|
||||
vol.Required(ATTR_ENTITY_ID): cv.entity_id,
|
||||
vol.Required(ATTR_INSTRUCTIONS): cv.string,
|
||||
vol.Optional(ATTR_ATTACHMENTS): vol.All(
|
||||
cv.ensure_list, [selector.MediaSelector({"accept": ["*/*"]})]
|
||||
),
|
||||
}
|
||||
),
|
||||
supports_response=SupportsResponse.ONLY,
|
||||
job_type=HassJobType.Coroutinefunction,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@@ -115,11 +150,16 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
|
||||
|
||||
async def async_service_generate_data(call: ServiceCall) -> ServiceResponse:
|
||||
"""Run the run task service."""
|
||||
"""Run the data task service."""
|
||||
result = await async_generate_data(hass=call.hass, **call.data)
|
||||
return result.as_dict()
|
||||
|
||||
|
||||
async def async_service_generate_image(call: ServiceCall) -> ServiceResponse:
|
||||
"""Run the image task service."""
|
||||
return await async_generate_image(hass=call.hass, **call.data)
|
||||
|
||||
|
||||
class AITaskPreferences:
|
||||
"""AI Task preferences."""
|
||||
|
||||
@@ -164,3 +204,29 @@ class AITaskPreferences:
|
||||
def as_dict(self) -> dict[str, str | None]:
|
||||
"""Get the current preferences."""
|
||||
return {key: getattr(self, key) for key in self.KEYS}
|
||||
|
||||
|
||||
class ImageView(HomeAssistantView):
|
||||
"""View to generated images."""
|
||||
|
||||
url = f"/api/{DOMAIN}/images/{{filename}}"
|
||||
name = f"api:{DOMAIN}/images"
|
||||
requires_auth = False
|
||||
|
||||
async def get(
|
||||
self,
|
||||
request: web.Request,
|
||||
filename: str,
|
||||
) -> web.Response:
|
||||
"""Serve image."""
|
||||
hass = request.app[KEY_HASS]
|
||||
image_storage = hass.data[DATA_IMAGES]
|
||||
image_data = image_storage.get(filename)
|
||||
|
||||
if image_data is None:
|
||||
raise web.HTTPNotFound
|
||||
|
||||
return web.Response(
|
||||
body=image_data.data,
|
||||
content_type=image_data.mime_type,
|
||||
)
|
||||
|
@@ -12,12 +12,18 @@ if TYPE_CHECKING:
|
||||
|
||||
from . import AITaskPreferences
|
||||
from .entity import AITaskEntity
|
||||
from .task import ImageData
|
||||
|
||||
DOMAIN = "ai_task"
|
||||
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
||||
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
||||
DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images")
|
||||
|
||||
IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour
|
||||
MAX_IMAGES = 20
|
||||
|
||||
SERVICE_GENERATE_DATA = "generate_data"
|
||||
SERVICE_GENERATE_IMAGE = "generate_image"
|
||||
|
||||
ATTR_INSTRUCTIONS: Final = "instructions"
|
||||
ATTR_TASK_NAME: Final = "task_name"
|
||||
@@ -38,3 +44,6 @@ class AITaskEntityFeature(IntFlag):
|
||||
|
||||
SUPPORT_ATTACHMENTS = 2
|
||||
"""Support attachments with generate data."""
|
||||
|
||||
GENERATE_IMAGE = 4
|
||||
"""Generate images based on instructions."""
|
||||
|
@@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
|
||||
from .task import GenDataTask, GenDataTaskResult
|
||||
from .task import GenDataTask, GenDataTaskResult, GenImageTask, GenImageTaskResult
|
||||
|
||||
|
||||
class AITaskEntity(RestoreEntity):
|
||||
@@ -57,7 +57,7 @@ class AITaskEntity(RestoreEntity):
|
||||
async def _async_get_ai_task_chat_log(
|
||||
self,
|
||||
session: ChatSession,
|
||||
task: GenDataTask,
|
||||
task: GenDataTask | GenImageTask,
|
||||
) -> AsyncGenerator[ChatLog]:
|
||||
"""Context manager used to manage the ChatLog used during an AI Task."""
|
||||
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
||||
@@ -104,3 +104,23 @@ class AITaskEntity(RestoreEntity):
|
||||
) -> GenDataTaskResult:
|
||||
"""Handle a gen data task."""
|
||||
raise NotImplementedError
|
||||
|
||||
@final
|
||||
async def internal_async_generate_image(
|
||||
self,
|
||||
session: ChatSession,
|
||||
task: GenImageTask,
|
||||
) -> GenImageTaskResult:
|
||||
"""Run a gen image task."""
|
||||
self.__last_activity = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
async with self._async_get_ai_task_chat_log(session, task) as chat_log:
|
||||
return await self._async_generate_image(task, chat_log)
|
||||
|
||||
async def _async_generate_image(
|
||||
self,
|
||||
task: GenImageTask,
|
||||
chat_log: ChatLog,
|
||||
) -> GenImageTaskResult:
|
||||
"""Handle a gen image task."""
|
||||
raise NotImplementedError
|
||||
|
@@ -7,6 +7,9 @@
|
||||
"services": {
|
||||
"generate_data": {
|
||||
"service": "mdi:file-star-four-points-outline"
|
||||
},
|
||||
"generate_image": {
|
||||
"service": "mdi:star-four-points-box-outline"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"domain": "ai_task",
|
||||
"name": "AI Task",
|
||||
"after_dependencies": ["camera"],
|
||||
"after_dependencies": ["camera", "http"],
|
||||
"codeowners": ["@home-assistant/core"],
|
||||
"dependencies": ["conversation", "media_source"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/ai_task",
|
||||
|
81
homeassistant/components/ai_task/media_source.py
Normal file
81
homeassistant/components/ai_task/media_source.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Expose images as media sources."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from homeassistant.components.media_player import BrowseError, MediaClass
|
||||
from homeassistant.components.media_source import (
|
||||
BrowseMediaSource,
|
||||
MediaSource,
|
||||
MediaSourceItem,
|
||||
PlayMedia,
|
||||
Unresolvable,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DATA_IMAGES, DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource:
|
||||
"""Set up image media source."""
|
||||
_LOGGER.debug("Setting up image media source")
|
||||
return ImageMediaSource(hass)
|
||||
|
||||
|
||||
class ImageMediaSource(MediaSource):
|
||||
"""Provide images as media sources."""
|
||||
|
||||
name: str = "AI Generated Images"
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize ImageMediaSource."""
|
||||
super().__init__(DOMAIN)
|
||||
self.hass = hass
|
||||
|
||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||
"""Resolve media to a url."""
|
||||
image_storage = self.hass.data[DATA_IMAGES]
|
||||
image = image_storage.get(item.identifier)
|
||||
|
||||
if image is None:
|
||||
raise Unresolvable(f"Could not resolve media item: {item.identifier}")
|
||||
|
||||
return PlayMedia(f"/api/{DOMAIN}/images/{item.identifier}", image.mime_type)
|
||||
|
||||
async def async_browse_media(
|
||||
self,
|
||||
item: MediaSourceItem,
|
||||
) -> BrowseMediaSource:
|
||||
"""Return media."""
|
||||
if item.identifier:
|
||||
raise BrowseError("Unknown item")
|
||||
|
||||
image_storage = self.hass.data[DATA_IMAGES]
|
||||
|
||||
children = [
|
||||
BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
identifier=filename,
|
||||
media_class=MediaClass.IMAGE,
|
||||
media_content_type=image.mime_type,
|
||||
title=image.title or filename,
|
||||
can_play=True,
|
||||
can_expand=False,
|
||||
)
|
||||
for filename, image in image_storage.items()
|
||||
]
|
||||
|
||||
return BrowseMediaSource(
|
||||
domain=DOMAIN,
|
||||
identifier=None,
|
||||
media_class=MediaClass.APP,
|
||||
media_content_type="",
|
||||
title="AI Generated Images",
|
||||
can_play=False,
|
||||
can_expand=True,
|
||||
children_media_class=MediaClass.IMAGE,
|
||||
children=children,
|
||||
)
|
@@ -31,3 +31,30 @@ generate_data:
|
||||
media:
|
||||
accept:
|
||||
- "*"
|
||||
generate_image:
|
||||
fields:
|
||||
task_name:
|
||||
example: "picture of a dog"
|
||||
required: true
|
||||
selector:
|
||||
text:
|
||||
instructions:
|
||||
example: "Generate a high quality square image of a dog on transparent background"
|
||||
required: true
|
||||
selector:
|
||||
text:
|
||||
multiline: true
|
||||
entity_id:
|
||||
required: true
|
||||
selector:
|
||||
entity:
|
||||
filter:
|
||||
domain: ai_task
|
||||
supported_features:
|
||||
- ai_task.AITaskEntityFeature.GENERATE_IMAGE
|
||||
attachments:
|
||||
required: false
|
||||
selector:
|
||||
media:
|
||||
accept:
|
||||
- "*"
|
||||
|
@@ -25,6 +25,28 @@
|
||||
"description": "List of files to attach for multi-modal AI analysis."
|
||||
}
|
||||
}
|
||||
},
|
||||
"generate_image": {
|
||||
"name": "Generate image",
|
||||
"description": "Uses AI to generate image.",
|
||||
"fields": {
|
||||
"task_name": {
|
||||
"name": "Task name",
|
||||
"description": "Name of the task."
|
||||
},
|
||||
"instructions": {
|
||||
"name": "Instructions",
|
||||
"description": "Instructions that explains the image to be generated."
|
||||
},
|
||||
"entity_id": {
|
||||
"name": "Entity ID",
|
||||
"description": "Entity ID to run the task on."
|
||||
},
|
||||
"attachments": {
|
||||
"name": "Attachments",
|
||||
"description": "List of files to attach for using as references."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -3,6 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
@@ -11,11 +13,22 @@ from typing import Any
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import camera, conversation, media_source
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.core import HomeAssistant, ServiceResponse, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.chat_session import async_get_chat_session
|
||||
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.network import get_url
|
||||
from homeassistant.util import RE_SANITIZE_FILENAME, slugify
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||
from .const import (
|
||||
DATA_COMPONENT,
|
||||
DATA_IMAGES,
|
||||
DATA_PREFERENCES,
|
||||
DOMAIN,
|
||||
IMAGE_EXPIRY_TIME,
|
||||
MAX_IMAGES,
|
||||
AITaskEntityFeature,
|
||||
)
|
||||
|
||||
|
||||
def _save_camera_snapshot(image: camera.Image) -> Path:
|
||||
@@ -29,43 +42,15 @@ def _save_camera_snapshot(image: camera.Image) -> Path:
|
||||
return Path(temp_file.name)
|
||||
|
||||
|
||||
async def async_generate_data(
|
||||
async def _resolve_attachments(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
task_name: str,
|
||||
entity_id: str | None = None,
|
||||
instructions: str,
|
||||
structure: vol.Schema | None = None,
|
||||
session: ChatSession,
|
||||
attachments: list[dict] | None = None,
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
|
||||
|
||||
if entity_id is None:
|
||||
raise HomeAssistantError("No entity_id provided and no preferred entity set")
|
||||
|
||||
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
||||
if entity is None:
|
||||
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
||||
|
||||
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support generating data"
|
||||
)
|
||||
|
||||
# Resolve attachments
|
||||
) -> list[conversation.Attachment]:
|
||||
"""Resolve attachments for a task."""
|
||||
resolved_attachments: list[conversation.Attachment] = []
|
||||
created_files: list[Path] = []
|
||||
|
||||
if (
|
||||
attachments
|
||||
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support attachments"
|
||||
)
|
||||
|
||||
for attachment in attachments or []:
|
||||
media_content_id = attachment["media_content_id"]
|
||||
|
||||
@@ -104,20 +89,59 @@ async def async_generate_data(
|
||||
)
|
||||
)
|
||||
|
||||
if not created_files:
|
||||
return resolved_attachments
|
||||
|
||||
def cleanup_files() -> None:
|
||||
"""Cleanup temporary files."""
|
||||
for file in created_files:
|
||||
file.unlink(missing_ok=True)
|
||||
|
||||
@callback
|
||||
def cleanup_files_callback() -> None:
|
||||
"""Cleanup temporary files."""
|
||||
hass.async_add_executor_job(cleanup_files)
|
||||
|
||||
session.async_on_cleanup(cleanup_files_callback)
|
||||
|
||||
return resolved_attachments
|
||||
|
||||
|
||||
async def async_generate_data(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
task_name: str,
|
||||
entity_id: str | None = None,
|
||||
instructions: str,
|
||||
structure: vol.Schema | None = None,
|
||||
attachments: list[dict] | None = None,
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a data generation task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
|
||||
|
||||
if entity_id is None:
|
||||
raise HomeAssistantError("No entity_id provided and no preferred entity set")
|
||||
|
||||
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
||||
if entity is None:
|
||||
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
||||
|
||||
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support generating data"
|
||||
)
|
||||
|
||||
if (
|
||||
attachments
|
||||
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support attachments"
|
||||
)
|
||||
|
||||
with async_get_chat_session(hass) as session:
|
||||
if created_files:
|
||||
|
||||
def cleanup_files() -> None:
|
||||
"""Cleanup temporary files."""
|
||||
for file in created_files:
|
||||
file.unlink(missing_ok=True)
|
||||
|
||||
@callback
|
||||
def cleanup_files_callback() -> None:
|
||||
"""Cleanup temporary files."""
|
||||
hass.async_add_executor_job(cleanup_files)
|
||||
|
||||
session.async_on_cleanup(cleanup_files_callback)
|
||||
resolved_attachments = await _resolve_attachments(hass, session, attachments)
|
||||
|
||||
return await entity.internal_async_generate_data(
|
||||
session,
|
||||
@@ -130,6 +154,97 @@ async def async_generate_data(
|
||||
)
|
||||
|
||||
|
||||
def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None:
|
||||
"""Remove old images to keep the storage size under the limit."""
|
||||
if num_to_remove <= 0:
|
||||
return
|
||||
|
||||
if num_to_remove >= len(image_storage):
|
||||
image_storage.clear()
|
||||
return
|
||||
|
||||
sorted_images = sorted(
|
||||
image_storage.items(),
|
||||
key=lambda item: item[1].timestamp,
|
||||
)
|
||||
|
||||
for filename, _ in sorted_images[:num_to_remove]:
|
||||
image_storage.pop(filename, None)
|
||||
|
||||
|
||||
async def async_generate_image(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
task_name: str,
|
||||
entity_id: str,
|
||||
instructions: str,
|
||||
attachments: list[dict] | None = None,
|
||||
) -> ServiceResponse:
|
||||
"""Run an image generation task in the AI Task integration."""
|
||||
entity = hass.data[DATA_COMPONENT].get_entity(entity_id)
|
||||
if entity is None:
|
||||
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
||||
|
||||
if AITaskEntityFeature.GENERATE_IMAGE not in entity.supported_features:
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support generating images"
|
||||
)
|
||||
|
||||
if (
|
||||
attachments
|
||||
and AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features
|
||||
):
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support attachments"
|
||||
)
|
||||
|
||||
with async_get_chat_session(hass) as session:
|
||||
resolved_attachments = await _resolve_attachments(hass, session, attachments)
|
||||
|
||||
task_result = await entity.internal_async_generate_image(
|
||||
session,
|
||||
GenImageTask(
|
||||
name=task_name,
|
||||
instructions=instructions,
|
||||
attachments=resolved_attachments or None,
|
||||
),
|
||||
)
|
||||
|
||||
service_result = task_result.as_dict()
|
||||
image_data = service_result.pop("image_data")
|
||||
if service_result.get("revised_prompt") is None:
|
||||
service_result["revised_prompt"] = instructions
|
||||
|
||||
image_storage = hass.data[DATA_IMAGES]
|
||||
|
||||
if len(image_storage) + 1 > MAX_IMAGES:
|
||||
_cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES)
|
||||
|
||||
current_time = datetime.now()
|
||||
ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
|
||||
sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))
|
||||
filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}"
|
||||
|
||||
image_storage[filename] = ImageData(
|
||||
data=image_data,
|
||||
timestamp=int(current_time.timestamp()),
|
||||
mime_type=task_result.mime_type,
|
||||
title=service_result["revised_prompt"],
|
||||
)
|
||||
|
||||
def _purge_image(filename: str, now: datetime) -> None:
|
||||
"""Remove image from storage."""
|
||||
image_storage.pop(filename, None)
|
||||
|
||||
if IMAGE_EXPIRY_TIME > 0:
|
||||
async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename))
|
||||
|
||||
service_result["url"] = get_url(hass) + f"/api/{DOMAIN}/images/{filename}"
|
||||
service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}"
|
||||
|
||||
return service_result
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GenDataTask:
|
||||
"""Gen data task to be processed."""
|
||||
@@ -167,3 +282,80 @@ class GenDataTaskResult:
|
||||
"conversation_id": self.conversation_id,
|
||||
"data": self.data,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GenImageTask:
|
||||
"""Gen image task to be processed."""
|
||||
|
||||
name: str
|
||||
"""Name of the task."""
|
||||
|
||||
instructions: str
|
||||
"""Instructions on what needs to be done."""
|
||||
|
||||
attachments: list[conversation.Attachment] | None = None
|
||||
"""List of attachments to go along the instructions."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return task as a string."""
|
||||
return f"<GenImageTask {self.name}: {id(self)}>"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GenImageTaskResult:
|
||||
"""Result of gen image task."""
|
||||
|
||||
image_data: bytes
|
||||
"""Raw image data generated by the model."""
|
||||
|
||||
conversation_id: str
|
||||
"""Unique identifier for the conversation."""
|
||||
|
||||
mime_type: str
|
||||
"""MIME type of the generated image."""
|
||||
|
||||
width: int | None = None
|
||||
"""Width of the generated image, if available."""
|
||||
|
||||
height: int | None = None
|
||||
"""Height of the generated image, if available."""
|
||||
|
||||
model: str | None = None
|
||||
"""Model used to generate the image, if available."""
|
||||
|
||||
revised_prompt: str | None = None
|
||||
"""Revised prompt used to generate the image, if applicable."""
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return result as a dict."""
|
||||
return {
|
||||
"image_data": self.image_data,
|
||||
"conversation_id": self.conversation_id,
|
||||
"mime_type": self.mime_type,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"model": self.model,
|
||||
"revised_prompt": self.revised_prompt,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ImageData:
|
||||
"""Image data for stored generated images."""
|
||||
|
||||
data: bytes
|
||||
"""Raw image data."""
|
||||
|
||||
timestamp: int
|
||||
"""Timestamp when the image was generated, as a Unix timestamp."""
|
||||
|
||||
mime_type: str
|
||||
"""MIME type of the image."""
|
||||
|
||||
title: str
|
||||
"""Title of the image, usually the prompt used to generate it."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return image data as a string."""
|
||||
return f"<ImageData {self.title}: {id(self)}>"
|
||||
|
@@ -10,6 +10,8 @@ from homeassistant.components.ai_task import (
|
||||
AITaskEntityFeature,
|
||||
GenDataTask,
|
||||
GenDataTaskResult,
|
||||
GenImageTask,
|
||||
GenImageTaskResult,
|
||||
)
|
||||
from homeassistant.components.conversation import AssistantContent, ChatLog
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
@@ -36,13 +38,16 @@ class MockAITaskEntity(AITaskEntity):
|
||||
|
||||
_attr_name = "Test Task Entity"
|
||||
_attr_supported_features = (
|
||||
AITaskEntityFeature.GENERATE_DATA | AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
AITaskEntityFeature.GENERATE_DATA
|
||||
| AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
| AITaskEntityFeature.GENERATE_IMAGE
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
super().__init__()
|
||||
self.mock_generate_data_tasks = []
|
||||
self.mock_generate_image_tasks = []
|
||||
|
||||
async def _async_generate_data(
|
||||
self, task: GenDataTask, chat_log: ChatLog
|
||||
@@ -63,6 +68,24 @@ class MockAITaskEntity(AITaskEntity):
|
||||
data=data,
|
||||
)
|
||||
|
||||
async def _async_generate_image(
|
||||
self, task: GenImageTask, chat_log: ChatLog
|
||||
) -> GenImageTaskResult:
|
||||
"""Mock handling of generate image task."""
|
||||
self.mock_generate_image_tasks.append(task)
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(self.entity_id, "")
|
||||
)
|
||||
return GenImageTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
image_data=b"mock_image_data",
|
||||
mime_type="image/png",
|
||||
width=1536,
|
||||
height=1024,
|
||||
model="mock_model",
|
||||
revised_prompt="mock_revised_prompt",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
|
64
tests/components/ai_task/test_media_source.py
Normal file
64
tests/components/ai_task/test_media_source.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Test ai_task media source."""
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.ai_task import ImageData
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
@pytest.fixture(name="image_id")
|
||||
async def mock_image_generate(hass: HomeAssistant) -> str:
|
||||
"""Mock image generation and return the image_id."""
|
||||
image_storage = hass.data.setdefault("ai_task_images", {})
|
||||
filename = "2025-06-15_150640_test_task.png"
|
||||
image_storage[filename] = ImageData(
|
||||
data=b"A",
|
||||
timestamp=1750000000,
|
||||
mime_type="image/png",
|
||||
title="Mock Image",
|
||||
)
|
||||
return filename
|
||||
|
||||
|
||||
async def test_browsing(
|
||||
hass: HomeAssistant, init_components: None, image_id: str
|
||||
) -> None:
|
||||
"""Test browsing image media source."""
|
||||
item = await media_source.async_browse_media(hass, "media-source://ai_task")
|
||||
|
||||
assert item is not None
|
||||
assert item.title == "AI Generated Images"
|
||||
assert len(item.children) == 1
|
||||
assert item.children[0].media_content_type == "image/png"
|
||||
assert item.children[0].identifier == image_id
|
||||
assert item.children[0].title == "Mock Image"
|
||||
|
||||
with pytest.raises(
|
||||
media_source.BrowseError,
|
||||
match="Unknown item",
|
||||
):
|
||||
await media_source.async_browse_media(
|
||||
hass, "media-source://ai_task/invalid_path"
|
||||
)
|
||||
|
||||
|
||||
async def test_resolving(
|
||||
hass: HomeAssistant, init_components: None, image_id: str
|
||||
) -> None:
|
||||
"""Test resolving."""
|
||||
item = await media_source.async_resolve_media(
|
||||
hass, f"media-source://ai_task/{image_id}", None
|
||||
)
|
||||
assert item is not None
|
||||
assert item.url == f"/api/ai_task/images/{image_id}"
|
||||
assert item.mime_type == "image/png"
|
||||
|
||||
invalid_id = "aabbccddeeff"
|
||||
with pytest.raises(
|
||||
media_source.Unresolvable,
|
||||
match=f"Could not resolve media item: {invalid_id}",
|
||||
):
|
||||
await media_source.async_resolve_media(
|
||||
hass, f"media-source://ai_task/{invalid_id}", None
|
||||
)
|
@@ -1,6 +1,6 @@
|
||||
"""Test tasks for the AI Task integration."""
|
||||
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -9,7 +9,12 @@ import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
|
||||
from homeassistant.components.ai_task import (
|
||||
AITaskEntityFeature,
|
||||
ImageData,
|
||||
async_generate_data,
|
||||
async_generate_image,
|
||||
)
|
||||
from homeassistant.components.camera import Image
|
||||
from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
@@ -232,7 +237,9 @@ async def test_generate_data_mixed_attachments(
|
||||
hass,
|
||||
dt_util.utcnow() + chat_session.CONVERSATION_TIMEOUT + timedelta(seconds=1),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done() # Need several iterations
|
||||
await hass.async_block_till_done() # because one iteration of the loop
|
||||
await hass.async_block_till_done() # simply schedules the cleanup
|
||||
|
||||
# Verify the temporary file cleaned up
|
||||
assert not camera_attachment.path.exists()
|
||||
@@ -242,3 +249,94 @@ async def test_generate_data_mixed_attachments(
|
||||
assert media_attachment.media_content_id == "media-source://media_player/video.mp4"
|
||||
assert media_attachment.mime_type == "video/mp4"
|
||||
assert media_attachment.path == Path("/media/test.mp4")
|
||||
|
||||
|
||||
async def test_generate_image(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test generating image service."""
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
|
||||
):
|
||||
await async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.unknown",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
result = await async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert "image_data" not in result
|
||||
assert result["media_source_id"].startswith("media-source://ai_task/images/")
|
||||
assert result["media_source_id"].endswith("_test_task.png")
|
||||
assert result["url"].startswith("http://10.10.10.10:8123/api/ai_task/images/")
|
||||
assert result["url"].endswith("_test_task.png")
|
||||
assert result["mime_type"] == "image/png"
|
||||
assert result["model"] == "mock_model"
|
||||
assert result["revised_prompt"] == "mock_revised_prompt"
|
||||
assert result["height"] == 1024
|
||||
assert result["width"] == 1536
|
||||
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state != STATE_UNKNOWN
|
||||
|
||||
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="AI Task entity ai_task.test_task_entity does not support generating images",
|
||||
):
|
||||
await async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
|
||||
async def test_image_cleanup(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test image cache cleanup."""
|
||||
image_storage = hass.data.setdefault("ai_task_images", {})
|
||||
image_storage.clear()
|
||||
image_storage.update(
|
||||
{
|
||||
str(idx): ImageData(
|
||||
data=b"mock_image_data",
|
||||
timestamp=int(datetime.now().timestamp()),
|
||||
mime_type="image/png",
|
||||
title="Test Image",
|
||||
)
|
||||
for idx in range(20)
|
||||
}
|
||||
)
|
||||
assert len(image_storage) == 20
|
||||
|
||||
result = await async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
assert result["url"].split("/")[-1] in image_storage
|
||||
assert len(image_storage) == 20
|
||||
|
||||
async_fire_time_changed(hass, dt_util.utcnow() + timedelta(hours=1, seconds=1))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert len(image_storage) == 19
|
||||
|
Reference in New Issue
Block a user