OpenAI ai_task image generation support (#151238)

This commit is contained in:
Denis Shulyaka
2025-08-27 21:43:27 +03:00
committed by GitHub
parent bad75222ed
commit de62991e5b
5 changed files with 207 additions and 8 deletions

View File

@@ -2,8 +2,12 @@
from __future__ import annotations
import base64
from json import JSONDecodeError
import logging
from typing import TYPE_CHECKING
from openai.types.responses.response_output_item import ImageGenerationCall
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
@@ -12,8 +16,14 @@ from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads
from .const import CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL, UNSUPPORTED_IMAGE_MODELS
from .entity import OpenAIBaseLLMEntity
if TYPE_CHECKING:
from homeassistant.config_entries import ConfigSubentry
from . import OpenAIConfigEntry
_LOGGER = logging.getLogger(__name__)
@@ -39,10 +49,16 @@ class OpenAITaskEntity(
):
"""OpenAI AI Task entity."""
_attr_supported_features = (
ai_task.AITaskEntityFeature.GENERATE_DATA
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
)
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
    """Initialize the entity."""
    super().__init__(entry, subentry)
    # Every AI Task entity supports data generation with attachments.
    features = (
        ai_task.AITaskEntityFeature.GENERATE_DATA
        | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
    )
    # Only advertise image generation when the configured model is not in
    # the known-unsupported prefix list.
    configured_model = self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
    if not configured_model.startswith(tuple(UNSUPPORTED_IMAGE_MODELS)):
        features |= ai_task.AITaskEntityFeature.GENERATE_IMAGE
    self._attr_supported_features = features
async def _async_generate_data(
self,
@@ -78,3 +94,56 @@ class OpenAITaskEntity(
conversation_id=chat_log.conversation_id,
data=data,
)
async def _async_generate_image(
    self,
    task: ai_task.GenImageTask,
    chat_log: conversation.ChatLog,
) -> ai_task.GenImageTaskResult:
    """Handle a generate image task."""
    await self._async_handle_chat_log(chat_log, task.name, force_image=True)

    if not isinstance(chat_log.content[-1], conversation.AssistantContent):
        raise HomeAssistantError(
            "Last content in chat log is not an AssistantContent"
        )

    # Walk the trailing run of assistant messages, newest first, and keep
    # the most recent ImageGenerationCall that actually carries a result.
    selected: ImageGenerationCall | None = None
    for item in reversed(chat_log.content):
        if not isinstance(item, conversation.AssistantContent):
            break
        native = item.native
        if not isinstance(native, ImageGenerationCall):
            continue
        if selected is None or selected.result is None:
            selected = native
        else:
            # Older duplicate call — drop its payload to save memory.
            native.result = None

    if selected is None or selected.result is None:
        raise HomeAssistantError("No image returned")

    # Decode the base64 payload, then clear it from the chat log so the
    # (potentially large) image data is not retained there.
    image_data = base64.b64decode(selected.result)
    selected.result = None

    # Derive the MIME type from the call's output format, defaulting to PNG.
    mime_type = "image/png"
    if output_format := getattr(selected, "output_format", None):
        mime_type = f"image/{output_format}"

    # Size is reported as "WIDTHxHEIGHT" when present.
    width = height = None
    if size := getattr(selected, "size", None):
        width, height = size.split("x")

    return ai_task.GenImageTaskResult(
        image_data=image_data,
        conversation_id=chat_log.conversation_id,
        mime_type=mime_type,
        width=int(width) if width else None,
        height=int(height) if height else None,
        model="gpt-image-1",
        revised_prompt=getattr(selected, "revised_prompt", None),
    )

View File

@@ -60,6 +60,15 @@ UNSUPPORTED_WEB_SEARCH_MODELS: list[str] = [
"o3-mini",
]
# Model-name prefixes that do not support image generation. Matched with
# str.startswith(tuple(...)) against the configured chat model; entities
# whose model matches one of these do not advertise the GENERATE_IMAGE
# feature.
UNSUPPORTED_IMAGE_MODELS: list[str] = [
    "gpt-5",
    "o3-mini",
    "o4",
    "o1",
    "gpt-3.5",
    "gpt-4-turbo",
]
RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],

View File

@@ -7,7 +7,7 @@ from collections.abc import AsyncGenerator, Callable, Iterable
import json
from mimetypes import guess_file_type
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast
import openai
from openai._streaming import AsyncStream
@@ -37,14 +37,20 @@ from openai.types.responses import (
ResponseReasoningSummaryTextDeltaEvent,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ToolChoiceTypesParam,
ToolParam,
WebSearchToolParam,
)
from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.response_input_param import (
FunctionCallOutput,
ImageGenerationCall as ImageGenerationCallParam,
)
from openai.types.responses.response_output_item import ImageGenerationCall
from openai.types.responses.tool_param import (
CodeInterpreter,
CodeInterpreterContainerCodeInterpreterToolAuto,
ImageGeneration,
)
from openai.types.responses.web_search_tool_param import UserLocation
import voluptuous as vol
@@ -230,11 +236,15 @@ def _convert_content_to_param(
)
)
reasoning_summary = []
elif isinstance(content.native, ImageGenerationCall):
messages.append(
cast(ImageGenerationCallParam, content.native.to_dict())
)
return messages
async def _transform_stream(
async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place
chat_log: conversation.ChatLog,
stream: AsyncStream[ResponseStreamEvent],
) -> AsyncGenerator[
@@ -324,6 +334,9 @@ async def _transform_stream(
"tool_result": {"status": event.item.status},
}
last_role = "tool_result"
elif isinstance(event.item, ImageGenerationCall):
yield {"native": event.item}
last_summary_index = -1 # Trigger new assistant message on next turn
elif isinstance(event, ResponseTextDeltaEvent):
yield {"content": event.delta}
elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
@@ -429,6 +442,7 @@ class OpenAIBaseLLMEntity(Entity):
chat_log: conversation.ChatLog,
structure_name: str | None = None,
structure: vol.Schema | None = None,
force_image: bool = False,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
@@ -495,6 +509,17 @@ class OpenAIBaseLLMEntity(Entity):
)
model_args.setdefault("include", []).append("code_interpreter_call.outputs") # type: ignore[union-attr]
if force_image:
tools.append(
ImageGeneration(
type="image_generation",
input_fidelity="high",
output_format="png",
)
)
model_args["tool_choice"] = ToolChoiceTypesParam(type="image_generation")
model_args["store"] = True # Avoid sending image data back and forth
if tools:
model_args["tools"] = tools

View File

@@ -13,6 +13,8 @@ from openai.types.responses import (
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseImageGenCallCompletedEvent,
ResponseImageGenCallPartialImageEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
@@ -31,6 +33,7 @@ from openai.types.responses import (
)
from openai.types.responses.response_code_interpreter_tool_call import OutputLogs
from openai.types.responses.response_function_web_search import ActionSearch
from openai.types.responses.response_output_item import ImageGenerationCall
from openai.types.responses.response_reasoning_item import Summary
@@ -401,3 +404,45 @@ def create_code_interpreter_item(
)
return events
def create_image_gen_call_item(
    id: str, output_index: int, logs: str | None = None
) -> list[ResponseStreamEvent]:
    """Create the stream events for a completed image generation call.

    Emits a partial-image event (payload "QQ==", i.e. the byte b"A"), the
    completion event, and the final output item carrying the full
    ImageGenerationCall with size, format, and revised prompt metadata.
    """
    # NOTE(review): `logs` appears unused here — kept for signature parity
    # with the sibling create_* helpers; confirm before removing.
    return [
        ResponseImageGenCallPartialImageEvent(
            item_id=id,
            output_index=output_index,
            partial_image_b64="QQ==",
            partial_image_index=0,
            sequence_number=0,
            type="response.image_generation_call.partial_image",
            size="1536x1024",
            quality="medium",
            background="transparent",
            output_format="png",
        ),
        ResponseImageGenCallCompletedEvent(
            item_id=id,
            output_index=output_index,
            sequence_number=0,
            type="response.image_generation_call.completed",
        ),
        ResponseOutputItemDoneEvent(
            item=ImageGenerationCall(
                id=id,
                result="QQ==",
                status="completed",
                type="image_generation_call",
                background="transparent",
                output_format="png",
                quality="medium",
                revised_prompt="Mock revised prompt.",
                size="1536x1024",
            ),
            output_index=output_index,
            sequence_number=0,
            type="response.output_item.done",
        ),
    ]

View File

@@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
from . import create_message_item
from . import create_image_gen_call_item, create_message_item
from tests.common import MockConfigEntry
@@ -206,3 +206,54 @@ async def test_generate_data_with_attachments(
"type": "input_image",
},
]
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_image(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_create_stream: AsyncMock,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test AI Task image generation."""
    entity_id = "ai_task.openai_ai_task"

    # The entity must be registered against the config entry's
    # ai_task_data subentry.
    entity_entry = entity_registry.async_get(entity_id)
    ai_task_entry = next(
        subentry
        for subentry in mock_config_entry.subentries.values()
        if subentry.subentry_type == "ai_task_data"
    )
    assert entity_entry is not None
    assert entity_entry.config_entry_id == mock_config_entry.entry_id
    assert entity_entry.config_subentry_id == ai_task_entry.subentry_id

    # Stream an image generation call followed by an empty text message.
    mock_create_stream.return_value = [
        create_image_gen_call_item(id="ig_A", output_index=0),
        create_message_item(id="msg_A", text="", output_index=1),
    ]

    assert hass.data[ai_task.DATA_IMAGES] == {}

    result = await ai_task.async_generate_image(
        hass,
        task_name="Test Task",
        entity_id=entity_id,
        instructions="Generate test image",
    )

    # Result metadata comes from the mocked ImageGenerationCall.
    assert result["height"] == 1024
    assert result["width"] == 1536
    assert result["revised_prompt"] == "Mock revised prompt."
    assert result["mime_type"] == "image/png"
    assert result["model"] == "gpt-image-1"

    # The decoded image ("QQ==" -> b"A") is stored in hass.data.
    assert len(hass.data[ai_task.DATA_IMAGES]) == 1
    stored_image = next(iter(hass.data[ai_task.DATA_IMAGES].values()))
    assert stored_image.data == b"A"
    assert stored_image.mime_type == "image/png"
    assert stored_image.title == "Mock revised prompt."