mirror of
https://github.com/home-assistant/core.git
synced 2025-08-30 01:42:21 +02:00
OpenAI ai_task image generation support (#151238)
This commit is contained in:
@@ -2,8 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from json import JSONDecodeError
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openai.types.responses.response_output_item import ImageGenerationCall
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@@ -12,8 +16,14 @@ from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
from homeassistant.util.json import json_loads
|
||||
|
||||
from .const import CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL, UNSUPPORTED_IMAGE_MODELS
|
||||
from .entity import OpenAIBaseLLMEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.config_entries import ConfigSubentry
|
||||
|
||||
from . import OpenAIConfigEntry
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -39,10 +49,16 @@ class OpenAITaskEntity(
|
||||
):
|
||||
"""OpenAI AI Task entity."""
|
||||
|
||||
_attr_supported_features = (
|
||||
ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
)
|
||||
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
|
||||
"""Initialize the entity."""
|
||||
super().__init__(entry, subentry)
|
||||
self._attr_supported_features = (
|
||||
ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
|
||||
)
|
||||
model = self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
if not model.startswith(tuple(UNSUPPORTED_IMAGE_MODELS)):
|
||||
self._attr_supported_features |= ai_task.AITaskEntityFeature.GENERATE_IMAGE
|
||||
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
@@ -78,3 +94,56 @@ class OpenAITaskEntity(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
data=data,
|
||||
)
|
||||
|
||||
async def _async_generate_image(
|
||||
self,
|
||||
task: ai_task.GenImageTask,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenImageTaskResult:
|
||||
"""Handle a generate image task."""
|
||||
await self._async_handle_chat_log(chat_log, task.name, force_image=True)
|
||||
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
raise HomeAssistantError(
|
||||
"Last content in chat log is not an AssistantContent"
|
||||
)
|
||||
|
||||
image_call: ImageGenerationCall | None = None
|
||||
for content in reversed(chat_log.content):
|
||||
if not isinstance(content, conversation.AssistantContent):
|
||||
break
|
||||
if isinstance(content.native, ImageGenerationCall):
|
||||
if image_call is None or image_call.result is None:
|
||||
image_call = content.native
|
||||
else: # Remove image data from chat log to save memory
|
||||
content.native.result = None
|
||||
|
||||
if image_call is None or image_call.result is None:
|
||||
raise HomeAssistantError("No image returned")
|
||||
|
||||
image_data = base64.b64decode(image_call.result)
|
||||
image_call.result = None
|
||||
|
||||
if hasattr(image_call, "output_format") and (
|
||||
output_format := image_call.output_format
|
||||
):
|
||||
mime_type = f"image/{output_format}"
|
||||
else:
|
||||
mime_type = "image/png"
|
||||
|
||||
if hasattr(image_call, "size") and (size := image_call.size):
|
||||
width, height = tuple(size.split("x"))
|
||||
else:
|
||||
width, height = None, None
|
||||
|
||||
return ai_task.GenImageTaskResult(
|
||||
image_data=image_data,
|
||||
conversation_id=chat_log.conversation_id,
|
||||
mime_type=mime_type,
|
||||
width=int(width) if width else None,
|
||||
height=int(height) if height else None,
|
||||
model="gpt-image-1",
|
||||
revised_prompt=image_call.revised_prompt
|
||||
if hasattr(image_call, "revised_prompt")
|
||||
else None,
|
||||
)
|
||||
|
@@ -60,6 +60,15 @@ UNSUPPORTED_WEB_SEARCH_MODELS: list[str] = [
|
||||
"o3-mini",
|
||||
]
|
||||
|
||||
UNSUPPORTED_IMAGE_MODELS: list[str] = [
|
||||
"gpt-5",
|
||||
"o3-mini",
|
||||
"o4",
|
||||
"o1",
|
||||
"gpt-3.5",
|
||||
"gpt-4-turbo",
|
||||
]
|
||||
|
||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
|
@@ -7,7 +7,7 @@ from collections.abc import AsyncGenerator, Callable, Iterable
|
||||
import json
|
||||
from mimetypes import guess_file_type
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
import openai
|
||||
from openai._streaming import AsyncStream
|
||||
@@ -37,14 +37,20 @@ from openai.types.responses import (
|
||||
ResponseReasoningSummaryTextDeltaEvent,
|
||||
ResponseStreamEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ToolChoiceTypesParam,
|
||||
ToolParam,
|
||||
WebSearchToolParam,
|
||||
)
|
||||
from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
|
||||
from openai.types.responses.response_input_param import FunctionCallOutput
|
||||
from openai.types.responses.response_input_param import (
|
||||
FunctionCallOutput,
|
||||
ImageGenerationCall as ImageGenerationCallParam,
|
||||
)
|
||||
from openai.types.responses.response_output_item import ImageGenerationCall
|
||||
from openai.types.responses.tool_param import (
|
||||
CodeInterpreter,
|
||||
CodeInterpreterContainerCodeInterpreterToolAuto,
|
||||
ImageGeneration,
|
||||
)
|
||||
from openai.types.responses.web_search_tool_param import UserLocation
|
||||
import voluptuous as vol
|
||||
@@ -230,11 +236,15 @@ def _convert_content_to_param(
|
||||
)
|
||||
)
|
||||
reasoning_summary = []
|
||||
elif isinstance(content.native, ImageGenerationCall):
|
||||
messages.append(
|
||||
cast(ImageGenerationCallParam, content.native.to_dict())
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def _transform_stream(
|
||||
async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place
|
||||
chat_log: conversation.ChatLog,
|
||||
stream: AsyncStream[ResponseStreamEvent],
|
||||
) -> AsyncGenerator[
|
||||
@@ -324,6 +334,9 @@ async def _transform_stream(
|
||||
"tool_result": {"status": event.item.status},
|
||||
}
|
||||
last_role = "tool_result"
|
||||
elif isinstance(event.item, ImageGenerationCall):
|
||||
yield {"native": event.item}
|
||||
last_summary_index = -1 # Trigger new assistant message on next turn
|
||||
elif isinstance(event, ResponseTextDeltaEvent):
|
||||
yield {"content": event.delta}
|
||||
elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
|
||||
@@ -429,6 +442,7 @@ class OpenAIBaseLLMEntity(Entity):
|
||||
chat_log: conversation.ChatLog,
|
||||
structure_name: str | None = None,
|
||||
structure: vol.Schema | None = None,
|
||||
force_image: bool = False,
|
||||
) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
options = self.subentry.data
|
||||
@@ -495,6 +509,17 @@ class OpenAIBaseLLMEntity(Entity):
|
||||
)
|
||||
model_args.setdefault("include", []).append("code_interpreter_call.outputs") # type: ignore[union-attr]
|
||||
|
||||
if force_image:
|
||||
tools.append(
|
||||
ImageGeneration(
|
||||
type="image_generation",
|
||||
input_fidelity="high",
|
||||
output_format="png",
|
||||
)
|
||||
)
|
||||
model_args["tool_choice"] = ToolChoiceTypesParam(type="image_generation")
|
||||
model_args["store"] = True # Avoid sending image data back and forth
|
||||
|
||||
if tools:
|
||||
model_args["tools"] = tools
|
||||
|
||||
|
@@ -13,6 +13,8 @@ from openai.types.responses import (
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseFunctionWebSearch,
|
||||
ResponseImageGenCallCompletedEvent,
|
||||
ResponseImageGenCallPartialImageEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputMessage,
|
||||
@@ -31,6 +33,7 @@ from openai.types.responses import (
|
||||
)
|
||||
from openai.types.responses.response_code_interpreter_tool_call import OutputLogs
|
||||
from openai.types.responses.response_function_web_search import ActionSearch
|
||||
from openai.types.responses.response_output_item import ImageGenerationCall
|
||||
from openai.types.responses.response_reasoning_item import Summary
|
||||
|
||||
|
||||
@@ -401,3 +404,45 @@ def create_code_interpreter_item(
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
def create_image_gen_call_item(
|
||||
id: str, output_index: int, logs: str | None = None
|
||||
) -> list[ResponseStreamEvent]:
|
||||
"""Create a message item."""
|
||||
return [
|
||||
ResponseImageGenCallPartialImageEvent(
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
partial_image_b64="QQ==",
|
||||
partial_image_index=0,
|
||||
sequence_number=0,
|
||||
type="response.image_generation_call.partial_image",
|
||||
size="1536x1024",
|
||||
quality="medium",
|
||||
background="transparent",
|
||||
output_format="png",
|
||||
),
|
||||
ResponseImageGenCallCompletedEvent(
|
||||
item_id=id,
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.image_generation_call.completed",
|
||||
),
|
||||
ResponseOutputItemDoneEvent(
|
||||
item=ImageGenerationCall(
|
||||
id=id,
|
||||
result="QQ==",
|
||||
status="completed",
|
||||
type="image_generation_call",
|
||||
background="transparent",
|
||||
output_format="png",
|
||||
quality="medium",
|
||||
revised_prompt="Mock revised prompt.",
|
||||
size="1536x1024",
|
||||
),
|
||||
output_index=output_index,
|
||||
sequence_number=0,
|
||||
type="response.output_item.done",
|
||||
),
|
||||
]
|
||||
|
@@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import entity_registry as er, selector
|
||||
|
||||
from . import create_message_item
|
||||
from . import create_image_gen_call_item, create_message_item
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@@ -206,3 +206,54 @@ async def test_generate_data_with_attachments(
|
||||
"type": "input_image",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_generate_image(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_create_stream: AsyncMock,
|
||||
entity_registry: er.EntityRegistry,
|
||||
) -> None:
|
||||
"""Test AI Task image generation."""
|
||||
entity_id = "ai_task.openai_ai_task"
|
||||
|
||||
# Ensure entity is linked to the subentry
|
||||
entity_entry = entity_registry.async_get(entity_id)
|
||||
ai_task_entry = next(
|
||||
iter(
|
||||
entry
|
||||
for entry in mock_config_entry.subentries.values()
|
||||
if entry.subentry_type == "ai_task_data"
|
||||
)
|
||||
)
|
||||
assert entity_entry is not None
|
||||
assert entity_entry.config_entry_id == mock_config_entry.entry_id
|
||||
assert entity_entry.config_subentry_id == ai_task_entry.subentry_id
|
||||
|
||||
# Mock the OpenAI response stream
|
||||
mock_create_stream.return_value = [
|
||||
create_image_gen_call_item(id="ig_A", output_index=0),
|
||||
create_message_item(id="msg_A", text="", output_index=1),
|
||||
]
|
||||
|
||||
assert hass.data[ai_task.DATA_IMAGES] == {}
|
||||
|
||||
result = await ai_task.async_generate_image(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.openai_ai_task",
|
||||
instructions="Generate test image",
|
||||
)
|
||||
|
||||
assert result["height"] == 1024
|
||||
assert result["width"] == 1536
|
||||
assert result["revised_prompt"] == "Mock revised prompt."
|
||||
assert result["mime_type"] == "image/png"
|
||||
assert result["model"] == "gpt-image-1"
|
||||
|
||||
assert len(hass.data[ai_task.DATA_IMAGES]) == 1
|
||||
image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values()))
|
||||
assert image_data.data == b"A"
|
||||
assert image_data.mime_type == "image/png"
|
||||
assert image_data.title == "Mock revised prompt."
|
||||
|
Reference in New Issue
Block a user