diff --git a/homeassistant/components/openai_conversation/ai_task.py b/homeassistant/components/openai_conversation/ai_task.py index 5fc700a73ad..bc05671e48f 100644 --- a/homeassistant/components/openai_conversation/ai_task.py +++ b/homeassistant/components/openai_conversation/ai_task.py @@ -2,8 +2,12 @@ from __future__ import annotations +import base64 from json import JSONDecodeError import logging +from typing import TYPE_CHECKING + +from openai.types.responses.response_output_item import ImageGenerationCall from homeassistant.components import ai_task, conversation from homeassistant.config_entries import ConfigEntry @@ -12,8 +16,14 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.util.json import json_loads +from .const import CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL, UNSUPPORTED_IMAGE_MODELS from .entity import OpenAIBaseLLMEntity +if TYPE_CHECKING: + from homeassistant.config_entries import ConfigSubentry + + from . 
import OpenAIConfigEntry + _LOGGER = logging.getLogger(__name__) @@ -39,10 +49,16 @@ class OpenAITaskEntity( ): """OpenAI AI Task entity.""" - _attr_supported_features = ( - ai_task.AITaskEntityFeature.GENERATE_DATA - | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS - ) + def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None: + """Initialize the entity.""" + super().__init__(entry, subentry) + self._attr_supported_features = ( + ai_task.AITaskEntityFeature.GENERATE_DATA + | ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS + ) + model = self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) + if not model.startswith(tuple(UNSUPPORTED_IMAGE_MODELS)): + self._attr_supported_features |= ai_task.AITaskEntityFeature.GENERATE_IMAGE async def _async_generate_data( self, @@ -78,3 +94,56 @@ class OpenAITaskEntity( conversation_id=chat_log.conversation_id, data=data, ) + + async def _async_generate_image( + self, + task: ai_task.GenImageTask, + chat_log: conversation.ChatLog, + ) -> ai_task.GenImageTaskResult: + """Handle a generate image task.""" + await self._async_handle_chat_log(chat_log, task.name, force_image=True) + + if not isinstance(chat_log.content[-1], conversation.AssistantContent): + raise HomeAssistantError( + "Last content in chat log is not an AssistantContent" + ) + + image_call: ImageGenerationCall | None = None + for content in reversed(chat_log.content): + if not isinstance(content, conversation.AssistantContent): + break + if isinstance(content.native, ImageGenerationCall): + if image_call is None or image_call.result is None: + image_call = content.native + else: # Remove image data from chat log to save memory + content.native.result = None + + if image_call is None or image_call.result is None: + raise HomeAssistantError("No image returned") + + image_data = base64.b64decode(image_call.result) + image_call.result = None + + if hasattr(image_call, "output_format") and ( + output_format := image_call.output_format + ): 
+ mime_type = f"image/{output_format}" + else: + mime_type = "image/png" + + if hasattr(image_call, "size") and (size := image_call.size): + width, height = tuple(size.split("x")) + else: + width, height = None, None + + return ai_task.GenImageTaskResult( + image_data=image_data, + conversation_id=chat_log.conversation_id, + mime_type=mime_type, + width=int(width) if width else None, + height=int(height) if height else None, + model="gpt-image-1", + revised_prompt=image_call.revised_prompt + if hasattr(image_call, "revised_prompt") + else None, + ) diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index 2fd18913207..fda862e1dbe 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -60,6 +60,15 @@ UNSUPPORTED_WEB_SEARCH_MODELS: list[str] = [ "o3-mini", ] +UNSUPPORTED_IMAGE_MODELS: list[str] = [ + "gpt-5", + "o3-mini", + "o4", + "o1", + "gpt-3.5", + "gpt-4-turbo", +] + RECOMMENDED_CONVERSATION_OPTIONS = { CONF_RECOMMENDED: True, CONF_LLM_HASS_API: [llm.LLM_API_ASSIST], diff --git a/homeassistant/components/openai_conversation/entity.py b/homeassistant/components/openai_conversation/entity.py index 44d833c8e71..31e31a72915 100644 --- a/homeassistant/components/openai_conversation/entity.py +++ b/homeassistant/components/openai_conversation/entity.py @@ -7,7 +7,7 @@ from collections.abc import AsyncGenerator, Callable, Iterable import json from mimetypes import guess_file_type from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import openai from openai._streaming import AsyncStream @@ -37,14 +37,20 @@ from openai.types.responses import ( ResponseReasoningSummaryTextDeltaEvent, ResponseStreamEvent, ResponseTextDeltaEvent, + ToolChoiceTypesParam, ToolParam, WebSearchToolParam, ) from openai.types.responses.response_create_params import 
ResponseCreateParamsStreaming -from openai.types.responses.response_input_param import FunctionCallOutput +from openai.types.responses.response_input_param import ( + FunctionCallOutput, + ImageGenerationCall as ImageGenerationCallParam, +) +from openai.types.responses.response_output_item import ImageGenerationCall from openai.types.responses.tool_param import ( CodeInterpreter, CodeInterpreterContainerCodeInterpreterToolAuto, + ImageGeneration, ) from openai.types.responses.web_search_tool_param import UserLocation import voluptuous as vol @@ -230,11 +236,15 @@ def _convert_content_to_param( ) ) reasoning_summary = [] + elif isinstance(content.native, ImageGenerationCall): + messages.append( + cast(ImageGenerationCallParam, content.native.to_dict()) + ) return messages -async def _transform_stream( +async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place chat_log: conversation.ChatLog, stream: AsyncStream[ResponseStreamEvent], ) -> AsyncGenerator[ @@ -324,6 +334,9 @@ async def _transform_stream( "tool_result": {"status": event.item.status}, } last_role = "tool_result" + elif isinstance(event.item, ImageGenerationCall): + yield {"native": event.item} + last_summary_index = -1 # Trigger new assistant message on next turn elif isinstance(event, ResponseTextDeltaEvent): yield {"content": event.delta} elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent): @@ -429,6 +442,7 @@ class OpenAIBaseLLMEntity(Entity): chat_log: conversation.ChatLog, structure_name: str | None = None, structure: vol.Schema | None = None, + force_image: bool = False, ) -> None: """Generate an answer for the chat log.""" options = self.subentry.data @@ -495,6 +509,17 @@ class OpenAIBaseLLMEntity(Entity): ) model_args.setdefault("include", []).append("code_interpreter_call.outputs") # type: ignore[union-attr] + if force_image: + tools.append( + ImageGeneration( + type="image_generation", + input_fidelity="high", + output_format="png", + ) + ) + 
+ model_args["tool_choice"] = ToolChoiceTypesParam(type="image_generation") + model_args["store"] = True # Avoid sending image data back and forth + if tools: model_args["tools"] = tools diff --git a/tests/components/openai_conversation/__init__.py b/tests/components/openai_conversation/__init__.py index e8effca3bc5..fb19236034f 100644 --- a/tests/components/openai_conversation/__init__.py +++ b/tests/components/openai_conversation/__init__.py @@ -13,6 +13,8 @@ from openai.types.responses import ( ResponseFunctionCallArgumentsDoneEvent, ResponseFunctionToolCall, ResponseFunctionWebSearch, + ResponseImageGenCallCompletedEvent, + ResponseImageGenCallPartialImageEvent, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, ResponseOutputMessage, @@ -31,6 +33,7 @@ from openai.types.responses import ( ) from openai.types.responses.response_code_interpreter_tool_call import OutputLogs from openai.types.responses.response_function_web_search import ActionSearch +from openai.types.responses.response_output_item import ImageGenerationCall from openai.types.responses.response_reasoning_item import Summary @@ -401,3 +404,45 @@ def create_code_interpreter_item( ) return events + + + +def create_image_gen_call_item( + id: str, output_index: int, logs: str | None = None +) -> list[ResponseStreamEvent]: + """Create an image generation call item.""" + return [ + ResponseImageGenCallPartialImageEvent( + item_id=id, + output_index=output_index, + partial_image_b64="QQ==", + partial_image_index=0, + sequence_number=0, + type="response.image_generation_call.partial_image", + size="1536x1024", + quality="medium", + background="transparent", + output_format="png", + ), + ResponseImageGenCallCompletedEvent( + item_id=id, + output_index=output_index, + sequence_number=0, + type="response.image_generation_call.completed", + ), + ResponseOutputItemDoneEvent( + item=ImageGenerationCall( + id=id, + result="QQ==", + status="completed", + type="image_generation_call", + background="transparent",
output_format="png", + quality="medium", + revised_prompt="Mock revised prompt.", + size="1536x1024", + ), + output_index=output_index, + sequence_number=0, + type="response.output_item.done", + ), + ] diff --git a/tests/components/openai_conversation/test_ai_task.py b/tests/components/openai_conversation/test_ai_task.py index 14e3056c0e2..d5792ea4899 100644 --- a/tests/components/openai_conversation/test_ai_task.py +++ b/tests/components/openai_conversation/test_ai_task.py @@ -11,7 +11,7 @@ from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import entity_registry as er, selector -from . import create_message_item +from . import create_image_gen_call_item, create_message_item from tests.common import MockConfigEntry @@ -206,3 +206,54 @@ async def test_generate_data_with_attachments( "type": "input_image", }, ] + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_image( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_create_stream: AsyncMock, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task image generation.""" + entity_id = "ai_task.openai_ai_task" + + # Ensure entity is linked to the subentry + entity_entry = entity_registry.async_get(entity_id) + ai_task_entry = next( + iter( + entry + for entry in mock_config_entry.subentries.values() + if entry.subentry_type == "ai_task_data" + ) + ) + assert entity_entry is not None + assert entity_entry.config_entry_id == mock_config_entry.entry_id + assert entity_entry.config_subentry_id == ai_task_entry.subentry_id + + # Mock the OpenAI response stream + mock_create_stream.return_value = [ + create_image_gen_call_item(id="ig_A", output_index=0), + create_message_item(id="msg_A", text="", output_index=1), + ] + + assert hass.data[ai_task.DATA_IMAGES] == {} + + result = await ai_task.async_generate_image( + hass, + task_name="Test Task", + entity_id="ai_task.openai_ai_task", + 
instructions="Generate test image", + ) + + assert result["height"] == 1024 + assert result["width"] == 1536 + assert result["revised_prompt"] == "Mock revised prompt." + assert result["mime_type"] == "image/png" + assert result["model"] == "gpt-image-1" + + assert len(hass.data[ai_task.DATA_IMAGES]) == 1 + image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values())) + assert image_data.data == b"A" + assert image_data.mime_type == "image/png" + assert image_data.title == "Mock revised prompt."