From 2029996d19c2ef99f184b32783bd57b32be3145b Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 22 Jun 2025 03:44:43 +0000 Subject: [PATCH] Add AI Task platform to OpenAI --- .../openai_conversation/__init__.py | 25 +- .../components/openai_conversation/ai_task.py | 62 ++++ .../openai_conversation/config_flow.py | 77 +++-- .../components/openai_conversation/const.py | 14 + .../openai_conversation/conversation.py | 309 +---------------- .../components/openai_conversation/entity.py | 313 ++++++++++++++++++ .../openai_conversation/strings.json | 45 +++ homeassistant/config_entries.py | 5 + .../components/openai_conversation/common.py | 87 +++++ .../openai_conversation/conftest.py | 117 ++++++- .../openai_conversation/test_ai_task.py | 33 ++ .../openai_conversation/test_config_flow.py | 55 ++- .../openai_conversation/test_conversation.py | 177 +--------- 13 files changed, 799 insertions(+), 520 deletions(-) create mode 100644 homeassistant/components/openai_conversation/ai_task.py create mode 100644 homeassistant/components/openai_conversation/entity.py create mode 100644 tests/components/openai_conversation/common.py create mode 100644 tests/components/openai_conversation/test_ai_task.py diff --git a/homeassistant/components/openai_conversation/__init__.py b/homeassistant/components/openai_conversation/__init__.py index 9d3553fef94..7252c21789a 100644 --- a/homeassistant/components/openai_conversation/__init__.py +++ b/homeassistant/components/openai_conversation/__init__.py @@ -5,6 +5,7 @@ from __future__ import annotations import base64 from mimetypes import guess_file_type from pathlib import Path +from types import MappingProxyType import openai from openai.types.images_response import ImagesResponse @@ -48,9 +49,11 @@ from .const import ( CONF_REASONING_EFFORT, CONF_TEMPERATURE, CONF_TOP_P, + DEFAULT_AI_TASK_NAME, DEFAULT_CONVERSATION_NAME, DOMAIN, LOGGER, + RECOMMENDED_AI_TASK_OPTIONS, RECOMMENDED_CHAT_MODEL, RECOMMENDED_MAX_TOKENS, RECOMMENDED_REASONING_EFFORT, @@ -61,7 +64,10 @@ from .const import ( SERVICE_GENERATE_IMAGE = "generate_image" SERVICE_GENERATE_CONTENT = "generate_content" -PLATFORMS = (Platform.CONVERSATION,) +PLATFORMS = ( + Platform.AI_TASK, + Platform.CONVERSATION, +) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient] @@ -295,7 +301,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> if entry.version == 1: # Migrate from version 1 to version 2 # Move conversation-specific options to a subentry - subentry = ConfigSubentry( + conversation_subentry = ConfigSubentry( data=entry.options, subentry_type="conversation", title=DEFAULT_CONVERSATION_NAME, @@ -303,7 +309,16 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> ) hass.config_entries.async_add_subentry( entry, - subentry, + conversation_subentry, + ) + hass.config_entries.async_add_subentry( + entry, + ConfigSubentry( + data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS), + subentry_type="ai_task", + title=DEFAULT_AI_TASK_NAME, + unique_id=None, + ), ) # Migrate conversation entity to be linked to subentry @@ -312,8 +327,8 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> if entity_entry.domain == Platform.CONVERSATION: ent_reg.async_update_entity( entity_entry.entity_id, - config_subentry_id=subentry.subentry_id, - new_unique_id=subentry.subentry_id, + config_subentry_id=conversation_subentry.subentry_id, + new_unique_id=conversation_subentry.subentry_id, ) break diff --git a/homeassistant/components/openai_conversation/ai_task.py b/homeassistant/components/openai_conversation/ai_task.py new file mode 100644 index 00000000000..883d458c5cb --- /dev/null +++ b/homeassistant/components/openai_conversation/ai_task.py @@ -0,0 +1,62 @@ +"""AI Task integration for OpenAI Conversation.""" + +from __future__ import annotations + +from homeassistant.components import ai_task, conversation +from homeassistant.config_entries import ConfigEntry, ConfigSubentry +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback + +from . import OpenAIConfigEntry +from .const import DEFAULT_AI_TASK_NAME, LOGGER +from .entity import OpenAILLMBaseEntity + +ERROR_GETTING_RESPONSE = "Sorry, I had a problem getting a response from OpenAI." + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddConfigEntryEntitiesCallback, +) -> None: + """Set up AI Task entities.""" + for subentry in config_entry.subentries.values(): + if subentry.subentry_type != "ai_task": + continue + + async_add_entities( + [OpenAILLMTaskEntity(config_entry, subentry)], + config_subentry_id=subentry.subentry_id, + ) + + +class OpenAILLMTaskEntity(ai_task.AITaskEntity, OpenAILLMBaseEntity): + """OpenAI AI Task entity.""" + + _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_TEXT + + def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None: + """Initialize the agent.""" + super().__init__(entry, subentry) + self._attr_name = subentry.title or DEFAULT_AI_TASK_NAME + + async def _async_generate_text( + self, + task: ai_task.GenTextTask, + chat_log: conversation.ChatLog, + ) -> ai_task.GenTextTaskResult: + """Handle a generate text task.""" + await self._async_handle_chat_log(chat_log) + + if not isinstance(chat_log.content[-1], conversation.AssistantContent): + LOGGER.error( + "Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response", + chat_log.content[-1], + ) + raise HomeAssistantError(ERROR_GETTING_RESPONSE) + + return ai_task.GenTextTaskResult( + conversation_id=chat_log.conversation_id, + text=chat_log.content[-1].content or "", + ) diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py index f80ea55ec0f..91ce63d54cf 100644 --- a/homeassistant/components/openai_conversation/config_flow.py +++ b/homeassistant/components/openai_conversation/config_flow.py @@ -54,9 +54,12 @@ from .const import ( CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_USER_LOCATION, + DEFAULT_AI_TASK_NAME, DEFAULT_CONVERSATION_NAME, DOMAIN, + RECOMMENDED_AI_TASK_OPTIONS, RECOMMENDED_CHAT_MODEL, + RECOMMENDED_CONVERSATION_OPTIONS, RECOMMENDED_MAX_TOKENS, RECOMMENDED_REASONING_EFFORT, RECOMMENDED_TEMPERATURE, @@ -76,12 +79,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema( } ) -RECOMMENDED_OPTIONS = { - CONF_RECOMMENDED: True, - CONF_LLM_HASS_API: [llm.LLM_API_ASSIST], - CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT, -} - async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: """Validate the user input allows us to connect. @@ -126,10 +123,16 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN): subentries=[ { "subentry_type": "conversation", - "data": RECOMMENDED_OPTIONS, + "data": RECOMMENDED_CONVERSATION_OPTIONS, "title": DEFAULT_CONVERSATION_NAME, "unique_id": None, - } + }, + { + "subentry_type": "ai_task", + "data": RECOMMENDED_AI_TASK_OPTIONS, + "title": DEFAULT_AI_TASK_NAME, + "unique_id": None, + }, ], ) @@ -143,10 +146,13 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN): cls, config_entry: ConfigEntry ) -> dict[str, type[ConfigSubentryFlow]]: """Return subentries supported by this integration.""" - return {"conversation": ConversationSubentryFlowHandler} + return { + "conversation": LLMSubentryFlowHandler, + "ai_task": LLMSubentryFlowHandler, + } -class ConversationSubentryFlowHandler(ConfigSubentryFlow): +class LLMSubentryFlowHandler(ConfigSubentryFlow): """Flow for managing conversation subentries.""" last_rendered_recommended = False @@ -158,7 +164,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow): ) -> SubentryFlowResult: """Add a subentry.""" self.is_new = True - self.options = RECOMMENDED_OPTIONS.copy() + if self._subentry_type == "ai_task": + self.options = RECOMMENDED_AI_TASK_OPTIONS.copy() + else: + self.options = RECOMMENDED_CONVERSATION_OPTIONS.copy() return await self.async_step_init() async def async_step_reconfigure( @@ -190,28 +199,34 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow): step_schema: VolDictType = {} if self.is_new: - step_schema[vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME)] = ( - str + if CONF_NAME in options: + default_name = options[CONF_NAME] + elif self._subentry_type == "ai_task": + default_name = DEFAULT_AI_TASK_NAME + else: + default_name = DEFAULT_CONVERSATION_NAME + step_schema[vol.Required(CONF_NAME, default=default_name)] = str + + if self._subentry_type == "conversation": + step_schema.update( + { + vol.Optional( + CONF_PROMPT, + description={ + "suggested_value": options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT + ) + }, + ): TemplateSelector(), + vol.Optional(CONF_LLM_HASS_API): SelectSelector( + SelectSelectorConfig(options=hass_apis, multiple=True) + ), + } ) - step_schema.update( - { - vol.Optional( - CONF_PROMPT, - description={ - "suggested_value": options.get( - CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT - ) - }, - ): TemplateSelector(), - vol.Optional(CONF_LLM_HASS_API): SelectSelector( - SelectSelectorConfig(options=hass_apis, multiple=True) - ), - vol.Required( - CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) - ): bool, - } - ) + step_schema[ + vol.Required(CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)) + ] = bool if user_input is not None: if not user_input.get(CONF_LLM_HASS_API): diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index f90c05eed79..7c8b971ef12 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -2,10 +2,14 @@ import logging +from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.helpers import llm + DOMAIN = "openai_conversation" LOGGER: logging.Logger = logging.getLogger(__package__) DEFAULT_CONVERSATION_NAME = "OpenAI Conversation" +DEFAULT_AI_TASK_NAME = "OpenAI AI Task" CONF_CHAT_MODEL = "chat_model" CONF_FILENAMES = "filenames" @@ -32,6 +36,16 @@ RECOMMENDED_WEB_SEARCH = False RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium" RECOMMENDED_WEB_SEARCH_USER_LOCATION = False +RECOMMENDED_CONVERSATION_OPTIONS = { + CONF_RECOMMENDED: True, + CONF_LLM_HASS_API: [llm.LLM_API_ASSIST], + CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT, +} + +RECOMMENDED_AI_TASK_OPTIONS = { + CONF_RECOMMENDED: True, +} + UNSUPPORTED_MODELS: list[str] = [ "o1-mini", "o1-mini-2024-09-12", diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index cfcdc924fe5..24c0d27db93 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -1,74 +1,17 @@ """Conversation support for OpenAI.""" -from collections.abc import AsyncGenerator, Callable -import json -from typing import Any, Literal, cast - -import openai -from openai._streaming import AsyncStream -from openai.types.responses import ( - EasyInputMessageParam, - FunctionToolParam, - ResponseCompletedEvent, - ResponseErrorEvent, - ResponseFailedEvent, - ResponseFunctionCallArgumentsDeltaEvent, - ResponseFunctionCallArgumentsDoneEvent, - ResponseFunctionToolCall, - ResponseFunctionToolCallParam, - ResponseIncompleteEvent, - ResponseInputParam, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseOutputMessage, - ResponseOutputMessageParam, - ResponseReasoningItem, - ResponseReasoningItemParam, - ResponseStreamEvent, - ResponseTextDeltaEvent, - ToolParam, - WebSearchToolParam, -) -from openai.types.responses.response_input_param import FunctionCallOutput -from openai.types.responses.web_search_tool_param import UserLocation -from voluptuous_openapi import convert +from typing import Literal from homeassistant.components import assist_pipeline, conversation from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import device_registry as dr, intent, llm +from homeassistant.helpers import intent from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from . import OpenAIConfigEntry -from .const import ( - CONF_CHAT_MODEL, - CONF_MAX_TOKENS, - CONF_PROMPT, - CONF_REASONING_EFFORT, - CONF_TEMPERATURE, - CONF_TOP_P, - CONF_WEB_SEARCH, - CONF_WEB_SEARCH_CITY, - CONF_WEB_SEARCH_CONTEXT_SIZE, - CONF_WEB_SEARCH_COUNTRY, - CONF_WEB_SEARCH_REGION, - CONF_WEB_SEARCH_TIMEZONE, - CONF_WEB_SEARCH_USER_LOCATION, - DEFAULT_CONVERSATION_NAME, - DOMAIN, - LOGGER, - RECOMMENDED_CHAT_MODEL, - RECOMMENDED_MAX_TOKENS, - RECOMMENDED_REASONING_EFFORT, - RECOMMENDED_TEMPERATURE, - RECOMMENDED_TOP_P, - RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE, -) - -# Max number of back and forth with the LLM to generate a response -MAX_TOOL_ITERATIONS = 10 +from .const import CONF_PROMPT, DEFAULT_CONVERSATION_NAME, DOMAIN +from .entity import OpenAILLMBaseEntity async def async_setup_entry( @@ -87,152 +30,10 @@ async def async_setup_entry( ) -def _format_tool( - tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None -) -> FunctionToolParam: - """Format tool specification.""" - return FunctionToolParam( - type="function", - name=tool.name, - parameters=convert(tool.parameters, custom_serializer=custom_serializer), - description=tool.description, - strict=False, - ) - - -def _convert_content_to_param( - content: conversation.Content, -) -> ResponseInputParam: - """Convert any native chat message for this agent to the native format.""" - messages: ResponseInputParam = [] - if isinstance(content, conversation.ToolResultContent): - return [ - FunctionCallOutput( - type="function_call_output", - call_id=content.tool_call_id, - output=json.dumps(content.tool_result), - ) - ] - - if content.content: - role: Literal["user", "assistant", "system", "developer"] = content.role - if role == "system": - role = "developer" - messages.append( - EasyInputMessageParam(type="message", role=role, content=content.content) - ) - - if isinstance(content, conversation.AssistantContent) and content.tool_calls: - messages.extend( - ResponseFunctionToolCallParam( - type="function_call", - name=tool_call.tool_name, - arguments=json.dumps(tool_call.tool_args), - call_id=tool_call.id, - ) - for tool_call in content.tool_calls - ) - return messages - - -async def _transform_stream( - chat_log: conversation.ChatLog, - result: AsyncStream[ResponseStreamEvent], - messages: ResponseInputParam, -) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: - """Transform an OpenAI delta stream into HA format.""" - async for event in result: - LOGGER.debug("Received event: %s", event) - - if isinstance(event, ResponseOutputItemAddedEvent): - if isinstance(event.item, ResponseOutputMessage): - yield {"role": event.item.role} - elif isinstance(event.item, ResponseFunctionToolCall): - # OpenAI has tool calls as individual events - # while HA puts tool calls inside the assistant message. - # We turn them into individual assistant content for HA - # to ensure that tools are called as soon as possible. - yield {"role": "assistant"} - current_tool_call = event.item - elif isinstance(event, ResponseOutputItemDoneEvent): - item = event.item.model_dump() - item.pop("status", None) - if isinstance(event.item, ResponseReasoningItem): - messages.append(cast(ResponseReasoningItemParam, item)) - elif isinstance(event.item, ResponseOutputMessage): - messages.append(cast(ResponseOutputMessageParam, item)) - elif isinstance(event.item, ResponseFunctionToolCall): - messages.append(cast(ResponseFunctionToolCallParam, item)) - elif isinstance(event, ResponseTextDeltaEvent): - yield {"content": event.delta} - elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent): - current_tool_call.arguments += event.delta - elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent): - current_tool_call.status = "completed" - yield { - "tool_calls": [ - llm.ToolInput( - id=current_tool_call.call_id, - tool_name=current_tool_call.name, - tool_args=json.loads(current_tool_call.arguments), - ) - ] - } - elif isinstance(event, ResponseCompletedEvent): - if event.response.usage is not None: - chat_log.async_trace( - { - "stats": { - "input_tokens": event.response.usage.input_tokens, - "output_tokens": event.response.usage.output_tokens, - } - } - ) - elif isinstance(event, ResponseIncompleteEvent): - if event.response.usage is not None: - chat_log.async_trace( - { - "stats": { - "input_tokens": event.response.usage.input_tokens, - "output_tokens": event.response.usage.output_tokens, - } - } - ) - - if ( - event.response.incomplete_details - and event.response.incomplete_details.reason - ): - reason: str = event.response.incomplete_details.reason - else: - reason = "unknown reason" - - if reason == "max_output_tokens": - reason = "max output tokens reached" - elif reason == "content_filter": - reason = "content filter triggered" - - raise HomeAssistantError(f"OpenAI response incomplete: {reason}") - elif isinstance(event, ResponseFailedEvent): - if event.response.usage is not None: - chat_log.async_trace( - { - "stats": { - "input_tokens": event.response.usage.input_tokens, - "output_tokens": event.response.usage.output_tokens, - } - } - ) - reason = "unknown reason" - if event.response.error is not None: - reason = event.response.error.message - raise HomeAssistantError(f"OpenAI response failed: {reason}") - elif isinstance(event, ResponseErrorEvent): - raise HomeAssistantError(f"OpenAI response error: {event.message}") - - class OpenAIConversationEntity( - conversation.ConversationEntity, conversation.AbstractConversationAgent + conversation.ConversationEntity, + conversation.AbstractConversationAgent, + OpenAILLMBaseEntity, ): """OpenAI conversation agent.""" @@ -240,17 +41,8 @@ class OpenAIConversationEntity( def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None: """Initialize the agent.""" - self.entry = entry - self.subentry = subentry + super().__init__(entry, subentry) self._attr_name = subentry.title or DEFAULT_CONVERSATION_NAME - self._attr_unique_id = subentry.subentry_id - self._attr_device_info = dr.DeviceInfo( - identifiers={(DOMAIN, entry.entry_id)}, - name=entry.title, - manufacturer="OpenAI", - model="ChatGPT", - entry_type=dr.DeviceEntryType.SERVICE, - ) if self.subentry.data.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL @@ -306,91 +98,6 @@ class OpenAIConversationEntity( continue_conversation=chat_log.continue_conversation, ) - async def _async_handle_chat_log( - self, - chat_log: conversation.ChatLog, - ) -> None: - """Generate an answer for the chat log.""" - options = self.subentry.data - - tools: list[ToolParam] | None = None - if chat_log.llm_api: - tools = [ - _format_tool(tool, chat_log.llm_api.custom_serializer) - for tool in chat_log.llm_api.tools - ] - - if options.get(CONF_WEB_SEARCH): - web_search = WebSearchToolParam( - type="web_search_preview", - search_context_size=options.get( - CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE - ), - ) - if options.get(CONF_WEB_SEARCH_USER_LOCATION): - web_search["user_location"] = UserLocation( - type="approximate", - city=options.get(CONF_WEB_SEARCH_CITY, ""), - region=options.get(CONF_WEB_SEARCH_REGION, ""), - country=options.get(CONF_WEB_SEARCH_COUNTRY, ""), - timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""), - ) - if tools is None: - tools = [] - tools.append(web_search) - - model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) - messages = [ - m - for content in chat_log.content - for m in _convert_content_to_param(content) - ] - - client = self.entry.runtime_data - - # To prevent infinite loops, we limit the number of iterations - for _iteration in range(MAX_TOOL_ITERATIONS): - model_args = { - "model": model, - "input": messages, - "max_output_tokens": options.get( - CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS - ), - "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P), - "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), - "user": chat_log.conversation_id, - "stream": True, - } - if tools: - model_args["tools"] = tools - - if model.startswith("o"): - model_args["reasoning"] = { - "effort": options.get( - CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT - ) - } - else: - model_args["store"] = False - - try: - result = await client.responses.create(**model_args) - except openai.RateLimitError as err: - LOGGER.error("Rate limited by OpenAI: %s", err) - raise HomeAssistantError("Rate limited or insufficient funds") from err - except openai.OpenAIError as err: - LOGGER.error("Error talking to OpenAI: %s", err) - raise HomeAssistantError("Error talking to OpenAI") from err - - async for content in chat_log.async_add_delta_content_stream( - self.entity_id, _transform_stream(chat_log, result, messages) - ): - if not isinstance(content, conversation.AssistantContent): - messages.extend(_convert_content_to_param(content)) - - if not chat_log.unresponded_tool_results: - break - async def _async_entry_update_listener( self, hass: HomeAssistant, entry: ConfigEntry ) -> None: diff --git a/homeassistant/components/openai_conversation/entity.py b/homeassistant/components/openai_conversation/entity.py new file mode 100644 index 00000000000..34b04294731 --- /dev/null +++ b/homeassistant/components/openai_conversation/entity.py @@ -0,0 +1,313 @@ +"""Base class for OpenAI Conversation entities.""" + +from collections.abc import AsyncGenerator, Callable +import json +from typing import Any, Literal, cast + +import openai +from openai._streaming import AsyncStream +from openai.types.responses import ( + EasyInputMessageParam, + FunctionToolParam, + ResponseCompletedEvent, + ResponseErrorEvent, + ResponseFailedEvent, + ResponseFunctionCallArgumentsDeltaEvent, + ResponseFunctionCallArgumentsDoneEvent, + ResponseFunctionToolCall, + ResponseFunctionToolCallParam, + ResponseIncompleteEvent, + ResponseInputParam, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputMessageParam, + ResponseReasoningItem, + ResponseReasoningItemParam, + ResponseStreamEvent, + ResponseTextDeltaEvent, + ToolParam, + WebSearchToolParam, +) +from openai.types.responses.response_input_param import FunctionCallOutput +from openai.types.responses.web_search_tool_param import UserLocation +from voluptuous_openapi import convert + +from homeassistant.components import conversation +from homeassistant.config_entries import ConfigSubentry +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import device_registry as dr, llm +from homeassistant.helpers.entity import Entity + +from . import OpenAIConfigEntry +from .const import ( + CONF_CHAT_MODEL, + CONF_MAX_TOKENS, + CONF_REASONING_EFFORT, + CONF_TEMPERATURE, + CONF_TOP_P, + CONF_WEB_SEARCH, + CONF_WEB_SEARCH_CITY, + CONF_WEB_SEARCH_CONTEXT_SIZE, + CONF_WEB_SEARCH_COUNTRY, + CONF_WEB_SEARCH_REGION, + CONF_WEB_SEARCH_TIMEZONE, + CONF_WEB_SEARCH_USER_LOCATION, + DOMAIN, + LOGGER, + RECOMMENDED_CHAT_MODEL, + RECOMMENDED_MAX_TOKENS, + RECOMMENDED_REASONING_EFFORT, + RECOMMENDED_TEMPERATURE, + RECOMMENDED_TOP_P, + RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE, +) + +# Max number of back and forth with the LLM to generate a response +MAX_TOOL_ITERATIONS = 10 + + +def _format_tool( + tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None +) -> FunctionToolParam: + """Format tool specification.""" + return FunctionToolParam( + type="function", + name=tool.name, + parameters=convert(tool.parameters, custom_serializer=custom_serializer), + description=tool.description, + strict=False, + ) + + +def _convert_content_to_param( + content: conversation.Content, +) -> ResponseInputParam: + """Convert any native chat message for this agent to the native format.""" + messages: ResponseInputParam = [] + if isinstance(content, conversation.ToolResultContent): + return [ + FunctionCallOutput( + type="function_call_output", + call_id=content.tool_call_id, + output=json.dumps(content.tool_result), + ) + ] + + if content.content: + role: Literal["user", "assistant", "system", "developer"] = content.role + if role == "system": + role = "developer" + messages.append( + EasyInputMessageParam(type="message", role=role, content=content.content) + ) + + if isinstance(content, conversation.AssistantContent) and content.tool_calls: + messages.extend( + ResponseFunctionToolCallParam( + type="function_call", + name=tool_call.tool_name, + arguments=json.dumps(tool_call.tool_args), + call_id=tool_call.id, + ) + for tool_call in content.tool_calls + ) + return messages + + +async def _transform_stream( + chat_log: conversation.ChatLog, + result: AsyncStream[ResponseStreamEvent], + messages: ResponseInputParam, +) -> AsyncGenerator[conversation.AssistantContentDeltaDict]: + """Transform an OpenAI delta stream into HA format.""" + async for event in result: + LOGGER.debug("Received event: %s", event) + + if isinstance(event, ResponseOutputItemAddedEvent): + if isinstance(event.item, ResponseOutputMessage): + yield {"role": event.item.role} + elif isinstance(event.item, ResponseFunctionToolCall): + # OpenAI has tool calls as individual events + # while HA puts tool calls inside the assistant message. + # We turn them into individual assistant content for HA + # to ensure that tools are called as soon as possible. + yield {"role": "assistant"} + current_tool_call = event.item + elif isinstance(event, ResponseOutputItemDoneEvent): + item = event.item.model_dump() + item.pop("status", None) + if isinstance(event.item, ResponseReasoningItem): + messages.append(cast(ResponseReasoningItemParam, item)) + elif isinstance(event.item, ResponseOutputMessage): + messages.append(cast(ResponseOutputMessageParam, item)) + elif isinstance(event.item, ResponseFunctionToolCall): + messages.append(cast(ResponseFunctionToolCallParam, item)) + elif isinstance(event, ResponseTextDeltaEvent): + yield {"content": event.delta} + elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent): + current_tool_call.arguments += event.delta + elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent): + current_tool_call.status = "completed" + yield { + "tool_calls": [ + llm.ToolInput( + id=current_tool_call.call_id, + tool_name=current_tool_call.name, + tool_args=json.loads(current_tool_call.arguments), + ) + ] + } + elif isinstance(event, ResponseCompletedEvent): + if event.response.usage is not None: + chat_log.async_trace( + { + "stats": { + "input_tokens": event.response.usage.input_tokens, + "output_tokens": event.response.usage.output_tokens, + } + } + ) + elif isinstance(event, ResponseIncompleteEvent): + if event.response.usage is not None: + chat_log.async_trace( + { + "stats": { + "input_tokens": event.response.usage.input_tokens, + "output_tokens": event.response.usage.output_tokens, + } + } + ) + + if ( + event.response.incomplete_details + and event.response.incomplete_details.reason + ): + reason: str = event.response.incomplete_details.reason + else: + reason = "unknown reason" + + if reason == "max_output_tokens": + reason = "max output tokens reached" + elif reason == "content_filter": + reason = "content filter triggered" + + raise HomeAssistantError(f"OpenAI response incomplete: {reason}") + elif isinstance(event, ResponseFailedEvent): + if event.response.usage is not None: + chat_log.async_trace( + { + "stats": { + "input_tokens": event.response.usage.input_tokens, + "output_tokens": event.response.usage.output_tokens, + } + } + ) + reason = "unknown reason" + if event.response.error is not None: + reason = event.response.error.message + raise HomeAssistantError(f"OpenAI response failed: {reason}") + elif isinstance(event, ResponseErrorEvent): + raise HomeAssistantError(f"OpenAI response error: {event.message}") + + +class OpenAILLMBaseEntity(Entity): + """OpenAI conversation agent.""" + + def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None: + """Initialize the agent.""" + self.entry = entry + self.subentry = subentry + self._attr_unique_id = subentry.subentry_id + self._attr_device_info = dr.DeviceInfo( + identifiers={(DOMAIN, entry.entry_id)}, + name=entry.title, + manufacturer="OpenAI", + model="ChatGPT", + entry_type=dr.DeviceEntryType.SERVICE, + ) + + async def _async_handle_chat_log( + self, + chat_log: conversation.ChatLog, + ) -> None: + """Generate an answer for the chat log.""" + options = self.subentry.data + + tools: list[ToolParam] | None = None + if chat_log.llm_api: + tools = [ + _format_tool(tool, chat_log.llm_api.custom_serializer) + for tool in chat_log.llm_api.tools + ] + + if options.get(CONF_WEB_SEARCH): + web_search = WebSearchToolParam( + type="web_search_preview", + search_context_size=options.get( + CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE + ), + ) + if options.get(CONF_WEB_SEARCH_USER_LOCATION): + web_search["user_location"] = UserLocation( + type="approximate", + city=options.get(CONF_WEB_SEARCH_CITY, ""), + region=options.get(CONF_WEB_SEARCH_REGION, ""), + country=options.get(CONF_WEB_SEARCH_COUNTRY, ""), + timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""), + ) + if tools is None: + tools = [] + tools.append(web_search) + + model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) + messages = [ + m + for content in chat_log.content + for m in _convert_content_to_param(content) + ] + + client = self.entry.runtime_data + + # To prevent infinite loops, we limit the number of iterations + for _iteration in range(MAX_TOOL_ITERATIONS): + model_args = { + "model": model, + "input": messages, + "max_output_tokens": options.get( + CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS + ), + "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P), + "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), + "user": chat_log.conversation_id, + "stream": True, + } + if tools: + model_args["tools"] = tools + + if model.startswith("o"): + model_args["reasoning"] = { + "effort": options.get( + CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT + ) + } + else: + model_args["store"] = False + + try: + result = await client.responses.create(**model_args) + except openai.RateLimitError as err: + LOGGER.error("Rate limited by OpenAI: %s", err) + raise HomeAssistantError("Rate limited or insufficient funds") from err + except openai.OpenAIError as err: + LOGGER.error("Error talking to OpenAI: %s", err) + raise HomeAssistantError("Error talking to OpenAI") from err + + async for content in chat_log.async_add_delta_content_stream( + self.entity_id, _transform_stream(chat_log, result, messages) + ): + if not isinstance(content, conversation.AssistantContent): + messages.extend(_convert_content_to_param(content)) + + if not chat_log.unresponded_tool_results: + break diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index 696ebae4b6f..75488298f45 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -64,6 +64,51 @@ "error": { "model_not_supported": "This model is not supported, please select a different model" } + }, + "ai_task": { + "initiate_flow": { + "user": "Add AI task service", + "reconfigure": "Reconfigure AI task service" + }, + "entry_type": "AI task service", + "step": { + "init": { + "data": { + "name": "[%key:common::config_flow::data::name%]", + "recommended": "[%key:component::openai_conversation::config_subentries::conversation::step::init::data::recommended%]" + } + }, + "advanced": { + "title": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::title%]", + "data": { + "chat_model": "[%key:common::generic::model%]", + "max_tokens": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::max_tokens%]", + "temperature": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::temperature%]", + "top_p": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::top_p%]" + } + }, + "model": { + "title": "[%key:component::openai_conversation::config_subentries::conversation::step::model::title%]", + "data": { + "reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::reasoning_effort%]", + "web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::web_search%]", + "search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::search_context_size%]", + "user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::user_location%]" + }, + "data_description": { + "reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::reasoning_effort%]", + "web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::web_search%]", + "search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::search_context_size%]", + "user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::user_location%]" + } + } + }, + "abort": { + "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]" + }, + "error": { + "model_not_supported": "[%key:component::openai_conversation::config_subentries::conversation::error::model_not_supported%]" + } } }, "selector": { diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index c2481ae3fa3..ca3a78f8046 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -3420,6 +3420,11 @@ class ConfigSubentryFlow( """Return config entry id.""" return self.handler[0] + @property + def _subentry_type(self) -> str: + """Return type of subentry we are editing/creating.""" + return self.handler[1] + @callback def _get_entry(self) -> ConfigEntry: """Return the config entry linked to the current context.""" diff --git a/tests/components/openai_conversation/common.py b/tests/components/openai_conversation/common.py new file mode 100644 index 00000000000..332eeedf1d5 --- /dev/null +++ b/tests/components/openai_conversation/common.py @@ -0,0 +1,87 @@ +"""Common utilities for OpenAI conversation tests.""" + +from openai.types.responses import ( + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseStreamEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, +) + + +def create_message_item( + id: str, text: str | list[str], output_index: int +) -> list[ResponseStreamEvent]: + """Create a message item.""" + if isinstance(text, str): + text = [text] + + content = ResponseOutputText(annotations=[], text="", type="output_text") + events = [ + ResponseOutputItemAddedEvent( + item=ResponseOutputMessage( + id=id, + content=[], + type="message", + role="assistant", + status="in_progress", + ), + output_index=output_index, + type="response.output_item.added", + ), + ResponseContentPartAddedEvent( + content_index=0, + item_id=id, + output_index=output_index, + part=content, + type="response.content_part.added", + ), + ] + + content.text = "".join(text) + events.extend( + ResponseTextDeltaEvent( + content_index=0, + delta=delta, + item_id=id, + output_index=output_index, + type="response.output_text.delta", + ) + for delta in text + ) + + events.extend( + [ + ResponseTextDoneEvent( + content_index=0, + item_id=id, + output_index=output_index, + text="".join(text), + type="response.output_text.done", + ), + ResponseContentPartDoneEvent( + content_index=0, + item_id=id, + output_index=output_index, + part=content, + type="response.content_part.done", + ), + ResponseOutputItemDoneEvent( + item=ResponseOutputMessage( + id=id, + content=[content], + role="assistant", + status="completed", + type="message", + ), + output_index=output_index, + type="response.output_item.done", + ), + ] + ) + + return events diff --git a/tests/components/openai_conversation/conftest.py b/tests/components/openai_conversation/conftest.py index aa17c333a79..b8fd6c4007b 100644 --- a/tests/components/openai_conversation/conftest.py +++ b/tests/components/openai_conversation/conftest.py @@ -1,16 +1,35 @@ """Tests helpers.""" -from unittest.mock import patch +from collections.abc import Generator +from unittest.mock import AsyncMock, patch +from openai.types import ResponseFormatText +from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseCreatedEvent, + ResponseError, + ResponseErrorEvent, + ResponseFailedEvent, + ResponseIncompleteEvent, + ResponseInProgressEvent, + ResponseOutputItemDoneEvent, + ResponseTextConfig, +) +from openai.types.responses.response import IncompleteDetails import pytest -from homeassistant.components.openai_conversation.const import DEFAULT_CONVERSATION_NAME +from homeassistant.components.openai_conversation.const import ( + DEFAULT_AI_TASK_NAME, + DEFAULT_CONVERSATION_NAME, +) from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant from homeassistant.helpers import llm from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry +from tests.components.conversation import mock_chat_log # noqa: F401 @pytest.fixture @@ -29,7 +48,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry: "subentry_type": "conversation", "title": DEFAULT_CONVERSATION_NAME, "unique_id": None, - } + }, + { + "data": {}, + "subentry_type": "ai_task", + "title": DEFAULT_AI_TASK_NAME, + "unique_id": None, + }, ], ) entry.add_to_hass(hass) @@ -65,3 +90,89 @@ async def mock_init_component( async def setup_ha(hass: HomeAssistant) -> None: """Set up Home Assistant.""" assert await async_setup_component(hass, "homeassistant", {}) + + +@pytest.fixture +def mock_create_stream() -> Generator[AsyncMock]: + """Mock stream response.""" + + async def mock_generator(events, **kwargs): + response = Response( + id="resp_A", + created_at=1700000000, + error=None, + incomplete_details=None, + instructions=kwargs.get("instructions"), + metadata=kwargs.get("metadata", {}), + model=kwargs.get("model", "gpt-4o-mini"), + object="response", + output=[], + parallel_tool_calls=kwargs.get("parallel_tool_calls", True), + temperature=kwargs.get("temperature", 1.0), + tool_choice=kwargs.get("tool_choice", "auto"), + tools=kwargs.get("tools", []), + top_p=kwargs.get("top_p", 1.0), + max_output_tokens=kwargs.get("max_output_tokens", 100000), + previous_response_id=kwargs.get("previous_response_id"), + reasoning=kwargs.get("reasoning"), + status="in_progress", + text=kwargs.get( + "text", ResponseTextConfig(format=ResponseFormatText(type="text")) + ), + truncation=kwargs.get("truncation", "disabled"), + usage=None, + user=kwargs.get("user"), + store=kwargs.get("store", True), + ) + yield ResponseCreatedEvent( + response=response, + type="response.created", + ) + yield ResponseInProgressEvent( + response=response, + type="response.in_progress", + ) + response.status = "completed" + + for value in events: + if isinstance(value, ResponseOutputItemDoneEvent): + response.output.append(value.item) + elif isinstance(value, IncompleteDetails): + response.status = "incomplete" + response.incomplete_details = value + break + if isinstance(value, ResponseError): + response.status = "failed" + response.error = value + break + + yield value + + if isinstance(value, ResponseErrorEvent): + return + + if response.status == "incomplete": + yield ResponseIncompleteEvent( + response=response, + type="response.incomplete", + ) + elif response.status == "failed": + yield ResponseFailedEvent( + response=response, + type="response.failed", + ) + else: + yield ResponseCompletedEvent( + response=response, + type="response.completed", + ) + + with patch( + "openai.resources.responses.AsyncResponses.create", + AsyncMock(), + ) as mock_create: + mock_create.side_effect = lambda **kwargs: mock_generator( + mock_create.return_value.pop(0), **kwargs + ) + + yield mock_create diff --git a/tests/components/openai_conversation/test_ai_task.py b/tests/components/openai_conversation/test_ai_task.py new file mode 100644 index 00000000000..54d523420a6 --- /dev/null +++ b/tests/components/openai_conversation/test_ai_task.py @@ -0,0 +1,33 @@ +"""Test AI Task platform of OpenAI Conversation integration.""" + +from unittest.mock import AsyncMock + +import pytest + +from homeassistant.components import ai_task +from homeassistant.core import HomeAssistant + +from .common import create_message_item + +from tests.common import MockConfigEntry + + +@pytest.mark.usefixtures("mock_init_component") +async def test_ai_task_generate_text( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_create_stream: AsyncMock, +) -> None: + """Test that AI task can generate text.""" + entity_id = "ai_task.openai_ai_task" + mock_create_stream.return_value = [ + create_message_item(id="msg_A", text="Hi there!", output_index=0) + ] + + result = await ai_task.async_generate_text( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Test prompt", + ) + assert result.text == "Hi there!" diff --git a/tests/components/openai_conversation/test_config_flow.py b/tests/components/openai_conversation/test_config_flow.py index 299e3f02ccb..714d2d68585 100644 --- a/tests/components/openai_conversation/test_config_flow.py +++ b/tests/components/openai_conversation/test_config_flow.py @@ -8,7 +8,6 @@ from openai.types.responses import Response, ResponseOutputMessage, ResponseOutp import pytest from homeassistant import config_entries -from homeassistant.components.openai_conversation.config_flow import RECOMMENDED_OPTIONS from homeassistant.components.openai_conversation.const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, @@ -24,9 +23,12 @@ from homeassistant.components.openai_conversation.const import ( CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_USER_LOCATION, + DEFAULT_AI_TASK_NAME, DEFAULT_CONVERSATION_NAME, DOMAIN, + RECOMMENDED_AI_TASK_OPTIONS, RECOMMENDED_CHAT_MODEL, + RECOMMENDED_CONVERSATION_OPTIONS, RECOMMENDED_MAX_TOKENS, RECOMMENDED_TOP_P, ) @@ -77,10 +79,16 @@ async def test_form(hass: HomeAssistant) -> None: assert result2["subentries"] == [ { "subentry_type": "conversation", - "data": RECOMMENDED_OPTIONS, + "data": RECOMMENDED_CONVERSATION_OPTIONS, "title": DEFAULT_CONVERSATION_NAME, "unique_id": None, - } + }, + { + "subentry_type": "ai_task", + "data": RECOMMENDED_AI_TASK_OPTIONS, + "title": DEFAULT_AI_TASK_NAME, + "unique_id": None, + }, ] assert len(mock_setup_entry.mock_calls) == 1 @@ -104,19 +112,56 @@ async def test_creating_conversation_subentry( result2 = await hass.config_entries.subentries.async_configure( result["flow_id"], - {"name": "My Custom Agent", **RECOMMENDED_OPTIONS}, + {"name": "My Custom Agent", **RECOMMENDED_CONVERSATION_OPTIONS}, ) await hass.async_block_till_done() assert result2["type"] is FlowResultType.CREATE_ENTRY assert result2["title"] == "My Custom Agent" - processed_options = RECOMMENDED_OPTIONS.copy() + processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy() processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip() assert result2["data"] == processed_options +async def test_creating_ai_task_subentry( + hass: HomeAssistant, mock_config_entry, mock_init_component +) -> None: + """Test creating an AI task subentry.""" + with patch("openai.resources.models.AsyncModels.list"): + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "ai_task"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "init" + assert not result["errors"] + + old_subentries = set(mock_config_entry.subentries) + + with patch("openai.resources.models.AsyncModels.list"): + result2 = await hass.config_entries.subentries.async_configure( + result["flow_id"], + {"name": "My AI Task", **RECOMMENDED_AI_TASK_OPTIONS}, + ) + await hass.async_block_till_done() + + assert result2["type"] is FlowResultType.CREATE_ENTRY + assert result2["title"] == "My AI Task" + assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS + + assert len(mock_config_entry.subentries) == 3 + + new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0] + new_subentry = mock_config_entry.subentries[new_subentry_id] + + assert new_subentry.subentry_type == "ai_task" + assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS + assert new_subentry.title == "My AI Task" + + async def test_subentry_recommended( hass: HomeAssistant, mock_config_entry, mock_init_component ) -> None: diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 8621465bd14..b44d33cdf9d 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -1,35 +1,20 @@ """Tests for the OpenAI integration.""" -from collections.abc import Generator from unittest.mock import AsyncMock, patch import httpx from openai import AuthenticationError, RateLimitError -from openai.types import ResponseFormatText from openai.types.responses import ( - Response, - ResponseCompletedEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseCreatedEvent, ResponseError, ResponseErrorEvent, - ResponseFailedEvent, ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseFunctionToolCall, ResponseFunctionWebSearch, - ResponseIncompleteEvent, - ResponseInProgressEvent, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, - ResponseOutputMessage, - ResponseOutputText, ResponseReasoningItem, ResponseStreamEvent, - ResponseTextConfig, - ResponseTextDeltaEvent, - ResponseTextDoneEvent, ResponseWebSearchCallCompletedEvent, ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent, @@ -54,6 +39,8 @@ from homeassistant.core import Context, HomeAssistant from homeassistant.helpers import intent from homeassistant.setup import async_setup_component +from .common import create_message_item + from tests.common import MockConfigEntry from tests.components.conversation import ( MockChatLog, @@ -61,92 +48,6 @@ from tests.components.conversation import ( ) -@pytest.fixture -def mock_create_stream() -> Generator[AsyncMock]: - """Mock stream response.""" - - async def mock_generator(events, **kwargs): - response = Response( - id="resp_A", - created_at=1700000000, - error=None, - incomplete_details=None, - instructions=kwargs.get("instructions"), - metadata=kwargs.get("metadata", {}), - model=kwargs.get("model", "gpt-4o-mini"), - object="response", - output=[], - parallel_tool_calls=kwargs.get("parallel_tool_calls", True), - temperature=kwargs.get("temperature", 1.0), - tool_choice=kwargs.get("tool_choice", "auto"), - tools=kwargs.get("tools"), - top_p=kwargs.get("top_p", 1.0), - max_output_tokens=kwargs.get("max_output_tokens", 100000), - previous_response_id=kwargs.get("previous_response_id"), - reasoning=kwargs.get("reasoning"), - status="in_progress", - text=kwargs.get( - "text", ResponseTextConfig(format=ResponseFormatText(type="text")) - ), - truncation=kwargs.get("truncation", "disabled"), - usage=None, - user=kwargs.get("user"), - store=kwargs.get("store", True), - ) - yield ResponseCreatedEvent( - response=response, - type="response.created", - ) - yield ResponseInProgressEvent( - response=response, - type="response.in_progress", - ) - response.status = "completed" - - for value in events: - if isinstance(value, ResponseOutputItemDoneEvent): - response.output.append(value.item) - elif isinstance(value, IncompleteDetails): - response.status = "incomplete" - response.incomplete_details = value - break - if isinstance(value, ResponseError): - response.status = "failed" - response.error = value - break - - yield value - - if isinstance(value, ResponseErrorEvent): - return - - if response.status == "incomplete": - yield ResponseIncompleteEvent( - response=response, - type="response.incomplete", - ) - elif response.status == "failed": - yield ResponseFailedEvent( - response=response, - type="response.failed", - ) - else: - yield ResponseCompletedEvent( - response=response, - type="response.completed", - ) - - with patch( - "openai.resources.responses.AsyncResponses.create", - AsyncMock(), - ) as mock_create: - mock_create.side_effect = lambda **kwargs: mock_generator( - mock_create.return_value.pop(0), **kwargs - ) - - yield mock_create - - async def test_entity( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -341,80 +242,6 @@ async def test_conversation_agent( assert agent.supported_languages == "*" -def create_message_item( - id: str, text: str | list[str], output_index: int -) -> list[ResponseStreamEvent]: - """Create a message item.""" - if isinstance(text, str): - text = [text] - - content = ResponseOutputText(annotations=[], text="", type="output_text") - events = [ - ResponseOutputItemAddedEvent( - item=ResponseOutputMessage( - id=id, - content=[], - type="message", - role="assistant", - status="in_progress", - ), - output_index=output_index, - type="response.output_item.added", - ), - ResponseContentPartAddedEvent( - content_index=0, - item_id=id, - output_index=output_index, - part=content, - type="response.content_part.added", - ), - ] - - content.text = "".join(text) - events.extend( - ResponseTextDeltaEvent( - content_index=0, - delta=delta, - item_id=id, - output_index=output_index, - type="response.output_text.delta", - ) - for delta in text - ) - - events.extend( - [ - ResponseTextDoneEvent( - content_index=0, - item_id=id, - output_index=output_index, - text="".join(text), - type="response.output_text.done", - ), - ResponseContentPartDoneEvent( - content_index=0, - item_id=id, - output_index=output_index, - part=content, - type="response.content_part.done", - ), - ResponseOutputItemDoneEvent( - item=ResponseOutputMessage( - id=id, - content=[content], - role="assistant", - status="completed", - type="message", - ), - output_index=output_index, - type="response.output_item.done", - ), - ] - ) - - return events - - def create_function_tool_call_item( id: str, arguments: str | list[str], call_id: str, name: str, output_index: int ) -> list[ResponseStreamEvent]: