mirror of
https://github.com/home-assistant/core.git
synced 2025-08-30 09:51:37 +02:00
OpenAI ai_task image generation support (#151238)
This commit is contained in:
@@ -2,8 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from openai.types.responses.response_output_item import ImageGenerationCall
|
||||||
|
|
||||||
from homeassistant.components import ai_task, conversation
|
from homeassistant.components import ai_task, conversation
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
@@ -12,8 +16,14 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
from homeassistant.util.json import json_loads
|
from homeassistant.util.json import json_loads
|
||||||
|
|
||||||
|
from .const import CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL, UNSUPPORTED_IMAGE_MODELS
|
||||||
from .entity import OpenAIBaseLLMEntity
|
from .entity import OpenAIBaseLLMEntity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
|
|
||||||
|
from . import OpenAIConfigEntry
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -39,10 +49,16 @@ class OpenAITaskEntity(
|
|||||||
):
|
):
|
||||||
"""OpenAI AI Task entity."""
|
"""OpenAI AI Task entity."""
|
||||||
|
|
||||||
_attr_supported_features = (
|
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
    """Initialize the entity and work out which AI Task features it offers.

    Data generation and attachment support are always advertised. Image
    generation is added unless the configured chat model (falling back to
    ``RECOMMENDED_CHAT_MODEL``) starts with one of the prefixes listed in
    ``UNSUPPORTED_IMAGE_MODELS``.
    """
    super().__init__(entry, subentry)

    features = (
        ai_task.AITaskEntityFeature.GENERATE_DATA
        | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
    )

    chat_model = self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
    # str.startswith accepts a tuple, so every entry is tested as a prefix.
    if not chat_model.startswith(tuple(UNSUPPORTED_IMAGE_MODELS)):
        features |= ai_task.AITaskEntityFeature.GENERATE_IMAGE

    self._attr_supported_features = features
|
||||||
|
|
||||||
async def _async_generate_data(
|
async def _async_generate_data(
|
||||||
self,
|
self,
|
||||||
@@ -78,3 +94,56 @@ class OpenAITaskEntity(
|
|||||||
conversation_id=chat_log.conversation_id,
|
conversation_id=chat_log.conversation_id,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _async_generate_image(
    self,
    task: ai_task.GenImageTask,
    chat_log: conversation.ChatLog,
) -> ai_task.GenImageTaskResult:
    """Handle a generate image task.

    Runs the chat log with ``force_image=True`` and then picks, from the
    trailing run of assistant messages, the most recent image generation
    call that still carries a result. The base64 payload of that call is
    decoded and returned; payloads left on the chat log entries are cleared.

    Raises:
        HomeAssistantError: If the chat log does not end with assistant
            content, or if no image generation call with a result is found.
    """
    await self._async_handle_chat_log(chat_log, task.name, force_image=True)

    if not isinstance(chat_log.content[-1], conversation.AssistantContent):
        raise HomeAssistantError(
            "Last content in chat log is not an AssistantContent"
        )

    image_call: ImageGenerationCall | None = None
    # Walk backwards and stop at the first non-assistant entry, so only the
    # messages produced by this request are inspected.
    for content in reversed(chat_log.content):
        if not isinstance(content, conversation.AssistantContent):
            break
        if not isinstance(content.native, ImageGenerationCall):
            continue
        if image_call is None or image_call.result is None:
            image_call = content.native
        else:  # Remove image data from chat log to save memory
            content.native.result = None

    if image_call is None or image_call.result is None:
        raise HomeAssistantError("No image returned")

    image_data = base64.b64decode(image_call.result)
    # The decoded bytes are returned below; drop the base64 copy as well.
    image_call.result = None

    # output_format, size and revised_prompt are read with a default because
    # they may not be present on the call object.
    output_format = getattr(image_call, "output_format", None)
    mime_type = f"image/{output_format}" if output_format else "image/png"

    width: int | None = None
    height: int | None = None
    if size := getattr(image_call, "size", None):
        try:
            width, height = (int(part) for part in size.split("x"))
        except ValueError:
            # Anything that is not "<width>x<height>" (for example "auto",
            # which the image generation tool accepts as a requested size)
            # must not discard an image that was already decoded; only the
            # dimensions are left unset.
            width = height = None

    return ai_task.GenImageTaskResult(
        image_data=image_data,
        conversation_id=chat_log.conversation_id,
        mime_type=mime_type,
        width=width,
        height=height,
        model="gpt-image-1",
        revised_prompt=getattr(image_call, "revised_prompt", None),
    )
|
||||||
|
@@ -60,6 +60,15 @@ UNSUPPORTED_WEB_SEARCH_MODELS: list[str] = [
|
|||||||
"o3-mini",
|
"o3-mini",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Chat models for which image generation is not offered. Entries are matched
# as name prefixes (via str.startswith), so "gpt-5" also covers "gpt-5-mini".
UNSUPPORTED_IMAGE_MODELS: list[str] = [
    "gpt-5",
    "o3-mini",
    "o4",
    "o1",
    "gpt-3.5",
    "gpt-4-turbo",
]
|
||||||
|
|
||||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||||
CONF_RECOMMENDED: True,
|
CONF_RECOMMENDED: True,
|
||||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||||
|
@@ -7,7 +7,7 @@ from collections.abc import AsyncGenerator, Callable, Iterable
|
|||||||
import json
|
import json
|
||||||
from mimetypes import guess_file_type
|
from mimetypes import guess_file_type
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai._streaming import AsyncStream
|
from openai._streaming import AsyncStream
|
||||||
@@ -37,14 +37,20 @@ from openai.types.responses import (
|
|||||||
ResponseReasoningSummaryTextDeltaEvent,
|
ResponseReasoningSummaryTextDeltaEvent,
|
||||||
ResponseStreamEvent,
|
ResponseStreamEvent,
|
||||||
ResponseTextDeltaEvent,
|
ResponseTextDeltaEvent,
|
||||||
|
ToolChoiceTypesParam,
|
||||||
ToolParam,
|
ToolParam,
|
||||||
WebSearchToolParam,
|
WebSearchToolParam,
|
||||||
)
|
)
|
||||||
from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
|
from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
|
||||||
from openai.types.responses.response_input_param import FunctionCallOutput
|
from openai.types.responses.response_input_param import (
|
||||||
|
FunctionCallOutput,
|
||||||
|
ImageGenerationCall as ImageGenerationCallParam,
|
||||||
|
)
|
||||||
|
from openai.types.responses.response_output_item import ImageGenerationCall
|
||||||
from openai.types.responses.tool_param import (
|
from openai.types.responses.tool_param import (
|
||||||
CodeInterpreter,
|
CodeInterpreter,
|
||||||
CodeInterpreterContainerCodeInterpreterToolAuto,
|
CodeInterpreterContainerCodeInterpreterToolAuto,
|
||||||
|
ImageGeneration,
|
||||||
)
|
)
|
||||||
from openai.types.responses.web_search_tool_param import UserLocation
|
from openai.types.responses.web_search_tool_param import UserLocation
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@@ -230,11 +236,15 @@ def _convert_content_to_param(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
reasoning_summary = []
|
reasoning_summary = []
|
||||||
|
elif isinstance(content.native, ImageGenerationCall):
|
||||||
|
messages.append(
|
||||||
|
cast(ImageGenerationCallParam, content.native.to_dict())
|
||||||
|
)
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
async def _transform_stream(
|
async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place
|
||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
stream: AsyncStream[ResponseStreamEvent],
|
stream: AsyncStream[ResponseStreamEvent],
|
||||||
) -> AsyncGenerator[
|
) -> AsyncGenerator[
|
||||||
@@ -324,6 +334,9 @@ async def _transform_stream(
|
|||||||
"tool_result": {"status": event.item.status},
|
"tool_result": {"status": event.item.status},
|
||||||
}
|
}
|
||||||
last_role = "tool_result"
|
last_role = "tool_result"
|
||||||
|
elif isinstance(event.item, ImageGenerationCall):
|
||||||
|
yield {"native": event.item}
|
||||||
|
last_summary_index = -1 # Trigger new assistant message on next turn
|
||||||
elif isinstance(event, ResponseTextDeltaEvent):
|
elif isinstance(event, ResponseTextDeltaEvent):
|
||||||
yield {"content": event.delta}
|
yield {"content": event.delta}
|
||||||
elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
|
elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
|
||||||
@@ -429,6 +442,7 @@ class OpenAIBaseLLMEntity(Entity):
|
|||||||
chat_log: conversation.ChatLog,
|
chat_log: conversation.ChatLog,
|
||||||
structure_name: str | None = None,
|
structure_name: str | None = None,
|
||||||
structure: vol.Schema | None = None,
|
structure: vol.Schema | None = None,
|
||||||
|
force_image: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate an answer for the chat log."""
|
"""Generate an answer for the chat log."""
|
||||||
options = self.subentry.data
|
options = self.subentry.data
|
||||||
@@ -495,6 +509,17 @@ class OpenAIBaseLLMEntity(Entity):
|
|||||||
)
|
)
|
||||||
model_args.setdefault("include", []).append("code_interpreter_call.outputs") # type: ignore[union-attr]
|
model_args.setdefault("include", []).append("code_interpreter_call.outputs") # type: ignore[union-attr]
|
||||||
|
|
||||||
|
if force_image:
|
||||||
|
tools.append(
|
||||||
|
ImageGeneration(
|
||||||
|
type="image_generation",
|
||||||
|
input_fidelity="high",
|
||||||
|
output_format="png",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model_args["tool_choice"] = ToolChoiceTypesParam(type="image_generation")
|
||||||
|
model_args["store"] = True # Avoid sending image data back and forth
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
model_args["tools"] = tools
|
model_args["tools"] = tools
|
||||||
|
|
||||||
|
@@ -13,6 +13,8 @@ from openai.types.responses import (
|
|||||||
ResponseFunctionCallArgumentsDoneEvent,
|
ResponseFunctionCallArgumentsDoneEvent,
|
||||||
ResponseFunctionToolCall,
|
ResponseFunctionToolCall,
|
||||||
ResponseFunctionWebSearch,
|
ResponseFunctionWebSearch,
|
||||||
|
ResponseImageGenCallCompletedEvent,
|
||||||
|
ResponseImageGenCallPartialImageEvent,
|
||||||
ResponseOutputItemAddedEvent,
|
ResponseOutputItemAddedEvent,
|
||||||
ResponseOutputItemDoneEvent,
|
ResponseOutputItemDoneEvent,
|
||||||
ResponseOutputMessage,
|
ResponseOutputMessage,
|
||||||
@@ -31,6 +33,7 @@ from openai.types.responses import (
|
|||||||
)
|
)
|
||||||
from openai.types.responses.response_code_interpreter_tool_call import OutputLogs
|
from openai.types.responses.response_code_interpreter_tool_call import OutputLogs
|
||||||
from openai.types.responses.response_function_web_search import ActionSearch
|
from openai.types.responses.response_function_web_search import ActionSearch
|
||||||
|
from openai.types.responses.response_output_item import ImageGenerationCall
|
||||||
from openai.types.responses.response_reasoning_item import Summary
|
from openai.types.responses.response_reasoning_item import Summary
|
||||||
|
|
||||||
|
|
||||||
@@ -401,3 +404,45 @@ def create_code_interpreter_item(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return events
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
def create_image_gen_call_item(
    id: str, output_index: int, logs: str | None = None
) -> list[ResponseStreamEvent]:
    """Create the stream events for a finished image generation call.

    The returned list holds, in order, a partial-image event, the
    call-completed event and an output-item-done event that wraps the
    completed ``ImageGenerationCall``. The image payload is the base64
    string ``"QQ=="``. ``logs`` is accepted but not used.
    """
    partial_image_event = ResponseImageGenCallPartialImageEvent(
        item_id=id,
        output_index=output_index,
        partial_image_b64="QQ==",
        partial_image_index=0,
        sequence_number=0,
        type="response.image_generation_call.partial_image",
        size="1536x1024",
        quality="medium",
        background="transparent",
        output_format="png",
    )

    completed_event = ResponseImageGenCallCompletedEvent(
        item_id=id,
        output_index=output_index,
        sequence_number=0,
        type="response.image_generation_call.completed",
    )

    finished_call = ImageGenerationCall(
        id=id,
        result="QQ==",
        status="completed",
        type="image_generation_call",
        background="transparent",
        output_format="png",
        quality="medium",
        revised_prompt="Mock revised prompt.",
        size="1536x1024",
    )
    item_done_event = ResponseOutputItemDoneEvent(
        item=finished_call,
        output_index=output_index,
        sequence_number=0,
        type="response.output_item.done",
    )

    return [partial_image_event, completed_event, item_done_event]
|
||||||
|
@@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import entity_registry as er, selector
|
from homeassistant.helpers import entity_registry as er, selector
|
||||||
|
|
||||||
from . import create_message_item
|
from . import create_image_gen_call_item, create_message_item
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@@ -206,3 +206,54 @@ async def test_generate_data_with_attachments(
|
|||||||
"type": "input_image",
|
"type": "input_image",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
async def test_generate_image(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_create_stream: AsyncMock,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test that an AI Task image request returns and stores the image."""
    entity_id = "ai_task.openai_ai_task"

    # The entity must be registered against the AI Task subentry.
    entity_entry = entity_registry.async_get(entity_id)
    ai_task_entry = next(
        subentry
        for subentry in mock_config_entry.subentries.values()
        if subentry.subentry_type == "ai_task_data"
    )
    assert entity_entry is not None
    assert entity_entry.config_entry_id == mock_config_entry.entry_id
    assert entity_entry.config_subentry_id == ai_task_entry.subentry_id

    # Fake OpenAI stream: one image generation call, then an empty message.
    image_events = create_image_gen_call_item(id="ig_A", output_index=0)
    message_events = create_message_item(id="msg_A", text="", output_index=1)
    mock_create_stream.return_value = [image_events, message_events]

    # Nothing is stored before the request runs.
    assert hass.data[ai_task.DATA_IMAGES] == {}

    result = await ai_task.async_generate_image(
        hass,
        task_name="Test Task",
        entity_id=entity_id,
        instructions="Generate test image",
    )

    # Metadata comes from the mocked image generation call.
    assert result["height"] == 1024
    assert result["width"] == 1536
    assert result["revised_prompt"] == "Mock revised prompt."
    assert result["mime_type"] == "image/png"
    assert result["model"] == "gpt-image-1"

    # Exactly one image is stored; "QQ==" decodes to b"A".
    assert len(hass.data[ai_task.DATA_IMAGES]) == 1
    (stored_image,) = hass.data[ai_task.DATA_IMAGES].values()
    assert stored_image.data == b"A"
    assert stored_image.mime_type == "image/png"
    assert stored_image.title == "Mock revised prompt."
|
||||||
|
Reference in New Issue
Block a user