Add AI Task platform to OpenAI

Paulus Schoutsen
2025-06-22 03:44:43 +00:00
parent 3d5df43ebb
commit 2029996d19
13 changed files with 799 additions and 520 deletions

View File: homeassistant/components/openai_conversation/__init__.py

@@ -5,6 +5,7 @@ from __future__ import annotations
import base64
from mimetypes import guess_file_type
from pathlib import Path
+from types import MappingProxyType

import openai
from openai.types.images_response import ImagesResponse
@@ -48,9 +49,11 @@ from .const import (
    CONF_REASONING_EFFORT,
    CONF_TEMPERATURE,
    CONF_TOP_P,
+    DEFAULT_AI_TASK_NAME,
    DEFAULT_CONVERSATION_NAME,
    DOMAIN,
    LOGGER,
+    RECOMMENDED_AI_TASK_OPTIONS,
    RECOMMENDED_CHAT_MODEL,
    RECOMMENDED_MAX_TOKENS,
    RECOMMENDED_REASONING_EFFORT,
@@ -61,7 +64,10 @@ from .const import (
SERVICE_GENERATE_IMAGE = "generate_image"
SERVICE_GENERATE_CONTENT = "generate_content"
-PLATFORMS = (Platform.CONVERSATION,)
+PLATFORMS = (
+    Platform.AI_TASK,
+    Platform.CONVERSATION,
+)

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)

type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
@@ -295,7 +301,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
    if entry.version == 1:
        # Migrate from version 1 to version 2
        # Move conversation-specific options to a subentry
-        subentry = ConfigSubentry(
+        conversation_subentry = ConfigSubentry(
            data=entry.options,
            subentry_type="conversation",
            title=DEFAULT_CONVERSATION_NAME,
@@ -303,7 +309,16 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
        )
        hass.config_entries.async_add_subentry(
            entry,
-            subentry,
+            conversation_subentry,
+        )
+        hass.config_entries.async_add_subentry(
+            entry,
+            ConfigSubentry(
+                data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
+                subentry_type="ai_task",
+                title=DEFAULT_AI_TASK_NAME,
+                unique_id=None,
+            ),
        )
        # Migrate conversation entity to be linked to subentry
@@ -312,8 +327,8 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
            if entity_entry.domain == Platform.CONVERSATION:
                ent_reg.async_update_entity(
                    entity_entry.entity_id,
-                    config_subentry_id=subentry.subentry_id,
-                    new_unique_id=subentry.subentry_id,
+                    config_subentry_id=conversation_subentry.subentry_id,
+                    new_unique_id=conversation_subentry.subentry_id,
                )
                break
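
For orientation, a short sketch (an illustration, not part of the commit) of the subentry layout this migration yields for a version-1 entry; the "recommended" key is the value of CONF_RECOMMENDED in const.py:

# Hypothetical shape of the migrated entry's subentries (illustration only):
# the old options move into the conversation subentry, and an ai_task
# subentry with the recommended defaults is added next to it.
[
    {"subentry_type": "conversation", "data": dict(old_options), "title": "OpenAI Conversation"},
    {"subentry_type": "ai_task", "data": {"recommended": True}, "title": "OpenAI AI Task"},
]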

View File: homeassistant/components/openai_conversation/ai_task.py

@@ -0,0 +1,62 @@
"""AI Task integration for OpenAI Conversation."""
from __future__ import annotations
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenAIConfigEntry
from .const import DEFAULT_AI_TASK_NAME, LOGGER
from .entity import OpenAILLMBaseEntity
ERROR_GETTING_RESPONSE = "Sorry, I had a problem getting a response from OpenAI."
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task":
continue
async_add_entities(
[OpenAILLMTaskEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class OpenAILLMTaskEntity(ai_task.AITaskEntity, OpenAILLMBaseEntity):
"""OpenAI AI Task entity."""
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_TEXT
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the agent."""
super().__init__(entry, subentry)
self._attr_name = subentry.title or DEFAULT_AI_TASK_NAME
async def _async_generate_text(
self,
task: ai_task.GenTextTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenTextTaskResult:
"""Handle a generate text task."""
await self._async_handle_chat_log(chat_log)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
chat_log.content[-1],
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
return ai_task.GenTextTaskResult(
conversation_id=chat_log.conversation_id,
text=chat_log.content[-1].content or "",
)
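
The new entity is exercised through Home Assistant's ai_task helper; a minimal usage sketch, assuming the default subentry title (which yields the entity ID also used in the tests below):

# Minimal sketch: dispatch a text-generation task to the new platform.
# "ai_task.openai_ai_task" assumes the default "OpenAI AI Task" title.
from homeassistant.components import ai_task

async def generate_summary(hass):
    result = await ai_task.async_generate_text(
        hass,
        task_name="example_task",  # free-form label for the task
        entity_id="ai_task.openai_ai_task",
        instructions="Summarize the state of the house in one sentence.",
    )
    return result.text  # GenTextTaskResult.text carries the model output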

View File: homeassistant/components/openai_conversation/config_flow.py

@@ -54,9 +54,12 @@ from .const import (
    CONF_WEB_SEARCH_REGION,
    CONF_WEB_SEARCH_TIMEZONE,
    CONF_WEB_SEARCH_USER_LOCATION,
+    DEFAULT_AI_TASK_NAME,
    DEFAULT_CONVERSATION_NAME,
    DOMAIN,
+    RECOMMENDED_AI_TASK_OPTIONS,
    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_CONVERSATION_OPTIONS,
    RECOMMENDED_MAX_TOKENS,
    RECOMMENDED_REASONING_EFFORT,
    RECOMMENDED_TEMPERATURE,
@@ -76,12 +79,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
    }
)

-RECOMMENDED_OPTIONS = {
-    CONF_RECOMMENDED: True,
-    CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
-    CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
-}

async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
    """Validate the user input allows us to connect.
@@ -126,10 +123,16 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
            subentries=[
                {
                    "subentry_type": "conversation",
-                    "data": RECOMMENDED_OPTIONS,
+                    "data": RECOMMENDED_CONVERSATION_OPTIONS,
                    "title": DEFAULT_CONVERSATION_NAME,
                    "unique_id": None,
-                }
+                },
+                {
+                    "subentry_type": "ai_task",
+                    "data": RECOMMENDED_AI_TASK_OPTIONS,
+                    "title": DEFAULT_AI_TASK_NAME,
+                    "unique_id": None,
+                },
            ],
        )
@@ -143,10 +146,13 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
        cls, config_entry: ConfigEntry
    ) -> dict[str, type[ConfigSubentryFlow]]:
        """Return subentries supported by this integration."""
-        return {"conversation": ConversationSubentryFlowHandler}
+        return {
+            "conversation": LLMSubentryFlowHandler,
+            "ai_task": LLMSubentryFlowHandler,
+        }


-class ConversationSubentryFlowHandler(ConfigSubentryFlow):
+class LLMSubentryFlowHandler(ConfigSubentryFlow):
    """Flow for managing conversation subentries."""

    last_rendered_recommended = False
@@ -158,7 +164,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
    ) -> SubentryFlowResult:
        """Add a subentry."""
        self.is_new = True
-        self.options = RECOMMENDED_OPTIONS.copy()
+        if self._subentry_type == "ai_task":
+            self.options = RECOMMENDED_AI_TASK_OPTIONS.copy()
+        else:
+            self.options = RECOMMENDED_CONVERSATION_OPTIONS.copy()

        return await self.async_step_init()

    async def async_step_reconfigure(
@@ -190,10 +199,15 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
        step_schema: VolDictType = {}

        if self.is_new:
-            step_schema[vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME)] = (
-                str
-            )
+            if CONF_NAME in options:
+                default_name = options[CONF_NAME]
+            elif self._subentry_type == "ai_task":
+                default_name = DEFAULT_AI_TASK_NAME
+            else:
+                default_name = DEFAULT_CONVERSATION_NAME
+            step_schema[vol.Required(CONF_NAME, default=default_name)] = str

+        if self._subentry_type == "conversation":
            step_schema.update(
                {
                    vol.Optional(
@@ -207,12 +221,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
                    vol.Optional(CONF_LLM_HASS_API): SelectSelector(
                        SelectSelectorConfig(options=hass_apis, multiple=True)
                    ),
-                    vol.Required(
-                        CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
-                    ): bool,
                }
            )

+        step_schema[
+            vol.Required(CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False))
+        ] = bool

        if user_input is not None:
            if not user_input.get(CONF_LLM_HASS_API):
                user_input.pop(CONF_LLM_HASS_API, None)

View File: homeassistant/components/openai_conversation/const.py

@@ -2,10 +2,14 @@
import logging

+from homeassistant.const import CONF_LLM_HASS_API
+from homeassistant.helpers import llm

DOMAIN = "openai_conversation"
LOGGER: logging.Logger = logging.getLogger(__package__)

DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
+DEFAULT_AI_TASK_NAME = "OpenAI AI Task"

CONF_CHAT_MODEL = "chat_model"
CONF_FILENAMES = "filenames"
@@ -32,6 +36,16 @@ RECOMMENDED_WEB_SEARCH = False
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium"
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False

+RECOMMENDED_CONVERSATION_OPTIONS = {
+    CONF_RECOMMENDED: True,
+    CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
+    CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
+}
+
+RECOMMENDED_AI_TASK_OPTIONS = {
+    CONF_RECOMMENDED: True,
+}

UNSUPPORTED_MODELS: list[str] = [
    "o1-mini",
    "o1-mini-2024-09-12",

View File: homeassistant/components/openai_conversation/conversation.py

@@ -1,74 +1,17 @@
"""Conversation support for OpenAI.""" """Conversation support for OpenAI."""
from collections.abc import AsyncGenerator, Callable from typing import Literal
import json
from typing import Any, Literal, cast
import openai
from openai._streaming import AsyncStream
from openai.types.responses import (
EasyInputMessageParam,
FunctionToolParam,
ResponseCompletedEvent,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionToolCallParam,
ResponseIncompleteEvent,
ResponseInputParam,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputMessageParam,
ResponseReasoningItem,
ResponseReasoningItemParam,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ToolParam,
WebSearchToolParam,
)
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.web_search_tool_param import UserLocation
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import intent
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenAIConfigEntry from . import OpenAIConfigEntry
from .const import ( from .const import CONF_PROMPT, DEFAULT_CONVERSATION_NAME, DOMAIN
CONF_CHAT_MODEL, from .entity import OpenAILLMBaseEntity
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
)
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
async def async_setup_entry( async def async_setup_entry(
@@ -87,152 +30,10 @@ async def async_setup_entry(
    )

-def _format_tool(
-    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
-) -> FunctionToolParam:
-    """Format tool specification."""
-    return FunctionToolParam(
-        type="function",
-        name=tool.name,
-        parameters=convert(tool.parameters, custom_serializer=custom_serializer),
-        description=tool.description,
-        strict=False,
-    )
-
-def _convert_content_to_param(
-    content: conversation.Content,
-) -> ResponseInputParam:
-    """Convert any native chat message for this agent to the native format."""
-    messages: ResponseInputParam = []
-    if isinstance(content, conversation.ToolResultContent):
-        return [
-            FunctionCallOutput(
-                type="function_call_output",
-                call_id=content.tool_call_id,
-                output=json.dumps(content.tool_result),
-            )
-        ]
-
-    if content.content:
-        role: Literal["user", "assistant", "system", "developer"] = content.role
-        if role == "system":
-            role = "developer"
-        messages.append(
-            EasyInputMessageParam(type="message", role=role, content=content.content)
-        )
-
-    if isinstance(content, conversation.AssistantContent) and content.tool_calls:
-        messages.extend(
-            ResponseFunctionToolCallParam(
-                type="function_call",
-                name=tool_call.tool_name,
-                arguments=json.dumps(tool_call.tool_args),
-                call_id=tool_call.id,
-            )
-            for tool_call in content.tool_calls
-        )
-    return messages
-
-async def _transform_stream(
-    chat_log: conversation.ChatLog,
-    result: AsyncStream[ResponseStreamEvent],
-    messages: ResponseInputParam,
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
-    """Transform an OpenAI delta stream into HA format."""
-    async for event in result:
-        LOGGER.debug("Received event: %s", event)
-
-        if isinstance(event, ResponseOutputItemAddedEvent):
-            if isinstance(event.item, ResponseOutputMessage):
-                yield {"role": event.item.role}
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                # OpenAI has tool calls as individual events
-                # while HA puts tool calls inside the assistant message.
-                # We turn them into individual assistant content for HA
-                # to ensure that tools are called as soon as possible.
-                yield {"role": "assistant"}
-                current_tool_call = event.item
-        elif isinstance(event, ResponseOutputItemDoneEvent):
-            item = event.item.model_dump()
-            item.pop("status", None)
-            if isinstance(event.item, ResponseReasoningItem):
-                messages.append(cast(ResponseReasoningItemParam, item))
-            elif isinstance(event.item, ResponseOutputMessage):
-                messages.append(cast(ResponseOutputMessageParam, item))
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                messages.append(cast(ResponseFunctionToolCallParam, item))
-        elif isinstance(event, ResponseTextDeltaEvent):
-            yield {"content": event.delta}
-        elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
-            current_tool_call.arguments += event.delta
-        elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
-            current_tool_call.status = "completed"
-            yield {
-                "tool_calls": [
-                    llm.ToolInput(
-                        id=current_tool_call.call_id,
-                        tool_name=current_tool_call.name,
-                        tool_args=json.loads(current_tool_call.arguments),
-                    )
-                ]
-            }
-        elif isinstance(event, ResponseCompletedEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-        elif isinstance(event, ResponseIncompleteEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-
-            if (
-                event.response.incomplete_details
-                and event.response.incomplete_details.reason
-            ):
-                reason: str = event.response.incomplete_details.reason
-            else:
-                reason = "unknown reason"
-
-            if reason == "max_output_tokens":
-                reason = "max output tokens reached"
-            elif reason == "content_filter":
-                reason = "content filter triggered"
-
-            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
-        elif isinstance(event, ResponseFailedEvent):
-            if event.response.usage is not None:
-                chat_log.async_trace(
-                    {
-                        "stats": {
-                            "input_tokens": event.response.usage.input_tokens,
-                            "output_tokens": event.response.usage.output_tokens,
-                        }
-                    }
-                )
-            reason = "unknown reason"
-            if event.response.error is not None:
-                reason = event.response.error.message
-            raise HomeAssistantError(f"OpenAI response failed: {reason}")
-        elif isinstance(event, ResponseErrorEvent):
-            raise HomeAssistantError(f"OpenAI response error: {event.message}")

class OpenAIConversationEntity(
-    conversation.ConversationEntity, conversation.AbstractConversationAgent
+    conversation.ConversationEntity,
+    conversation.AbstractConversationAgent,
+    OpenAILLMBaseEntity,
):
    """OpenAI conversation agent."""
@@ -240,17 +41,8 @@ class OpenAIConversationEntity(
    def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the agent."""
-        self.entry = entry
-        self.subentry = subentry
+        super().__init__(entry, subentry)
        self._attr_name = subentry.title or DEFAULT_CONVERSATION_NAME
-        self._attr_unique_id = subentry.subentry_id
-        self._attr_device_info = dr.DeviceInfo(
-            identifiers={(DOMAIN, entry.entry_id)},
-            name=entry.title,
-            manufacturer="OpenAI",
-            model="ChatGPT",
-            entry_type=dr.DeviceEntryType.SERVICE,
-        )
        if self.subentry.data.get(CONF_LLM_HASS_API):
            self._attr_supported_features = (
                conversation.ConversationEntityFeature.CONTROL
@@ -306,91 +98,6 @@ class OpenAIConversationEntity(
            continue_conversation=chat_log.continue_conversation,
        )

-    async def _async_handle_chat_log(
-        self,
-        chat_log: conversation.ChatLog,
-    ) -> None:
-        """Generate an answer for the chat log."""
-        options = self.subentry.data
-
-        tools: list[ToolParam] | None = None
-        if chat_log.llm_api:
-            tools = [
-                _format_tool(tool, chat_log.llm_api.custom_serializer)
-                for tool in chat_log.llm_api.tools
-            ]
-
-        if options.get(CONF_WEB_SEARCH):
-            web_search = WebSearchToolParam(
-                type="web_search_preview",
-                search_context_size=options.get(
-                    CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
-                ),
-            )
-            if options.get(CONF_WEB_SEARCH_USER_LOCATION):
-                web_search["user_location"] = UserLocation(
-                    type="approximate",
-                    city=options.get(CONF_WEB_SEARCH_CITY, ""),
-                    region=options.get(CONF_WEB_SEARCH_REGION, ""),
-                    country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
-                    timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
-                )
-            if tools is None:
-                tools = []
-            tools.append(web_search)
-
-        model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
-        messages = [
-            m
-            for content in chat_log.content
-            for m in _convert_content_to_param(content)
-        ]
-
-        client = self.entry.runtime_data
-
-        # To prevent infinite loops, we limit the number of iterations
-        for _iteration in range(MAX_TOOL_ITERATIONS):
-            model_args = {
-                "model": model,
-                "input": messages,
-                "max_output_tokens": options.get(
-                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
-                ),
-                "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-                "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-                "user": chat_log.conversation_id,
-                "stream": True,
-            }
-            if tools:
-                model_args["tools"] = tools
-
-            if model.startswith("o"):
-                model_args["reasoning"] = {
-                    "effort": options.get(
-                        CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
-                    )
-                }
-            else:
-                model_args["store"] = False
-
-            try:
-                result = await client.responses.create(**model_args)
-            except openai.RateLimitError as err:
-                LOGGER.error("Rate limited by OpenAI: %s", err)
-                raise HomeAssistantError("Rate limited or insufficient funds") from err
-            except openai.OpenAIError as err:
-                LOGGER.error("Error talking to OpenAI: %s", err)
-                raise HomeAssistantError("Error talking to OpenAI") from err
-
-            async for content in chat_log.async_add_delta_content_stream(
-                self.entity_id, _transform_stream(chat_log, result, messages)
-            ):
-                if not isinstance(content, conversation.AssistantContent):
-                    messages.extend(_convert_content_to_param(content))
-
-            if not chat_log.unresponded_tool_results:
-                break

    async def _async_entry_update_listener(
        self, hass: HomeAssistant, entry: ConfigEntry
    ) -> None:

View File: homeassistant/components/openai_conversation/entity.py

@@ -0,0 +1,313 @@
"""Base class for OpenAI Conversation entities."""
from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal, cast
import openai
from openai._streaming import AsyncStream
from openai.types.responses import (
EasyInputMessageParam,
FunctionToolParam,
ResponseCompletedEvent,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionToolCallParam,
ResponseIncompleteEvent,
ResponseInputParam,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputMessageParam,
ResponseReasoningItem,
ResponseReasoningItemParam,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ToolParam,
WebSearchToolParam,
)
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.web_search_tool_param import UserLocation
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
from . import OpenAIConfigEntry
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
)
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> FunctionToolParam:
"""Format tool specification."""
return FunctionToolParam(
type="function",
name=tool.name,
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
description=tool.description,
strict=False,
)
def _convert_content_to_param(
content: conversation.Content,
) -> ResponseInputParam:
"""Convert any native chat message for this agent to the native format."""
messages: ResponseInputParam = []
if isinstance(content, conversation.ToolResultContent):
return [
FunctionCallOutput(
type="function_call_output",
call_id=content.tool_call_id,
output=json.dumps(content.tool_result),
)
]
if content.content:
role: Literal["user", "assistant", "system", "developer"] = content.role
if role == "system":
role = "developer"
messages.append(
EasyInputMessageParam(type="message", role=role, content=content.content)
)
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
messages.extend(
ResponseFunctionToolCallParam(
type="function_call",
name=tool_call.tool_name,
arguments=json.dumps(tool_call.tool_args),
call_id=tool_call.id,
)
for tool_call in content.tool_calls
)
return messages
async def _transform_stream(
chat_log: conversation.ChatLog,
result: AsyncStream[ResponseStreamEvent],
messages: ResponseInputParam,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform an OpenAI delta stream into HA format."""
async for event in result:
LOGGER.debug("Received event: %s", event)
if isinstance(event, ResponseOutputItemAddedEvent):
if isinstance(event.item, ResponseOutputMessage):
yield {"role": event.item.role}
elif isinstance(event.item, ResponseFunctionToolCall):
# OpenAI has tool calls as individual events
# while HA puts tool calls inside the assistant message.
# We turn them into individual assistant content for HA
# to ensure that tools are called as soon as possible.
yield {"role": "assistant"}
current_tool_call = event.item
elif isinstance(event, ResponseOutputItemDoneEvent):
item = event.item.model_dump()
item.pop("status", None)
if isinstance(event.item, ResponseReasoningItem):
messages.append(cast(ResponseReasoningItemParam, item))
elif isinstance(event.item, ResponseOutputMessage):
messages.append(cast(ResponseOutputMessageParam, item))
elif isinstance(event.item, ResponseFunctionToolCall):
messages.append(cast(ResponseFunctionToolCallParam, item))
elif isinstance(event, ResponseTextDeltaEvent):
yield {"content": event.delta}
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
current_tool_call.arguments += event.delta
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
current_tool_call.status = "completed"
yield {
"tool_calls": [
llm.ToolInput(
id=current_tool_call.call_id,
tool_name=current_tool_call.name,
tool_args=json.loads(current_tool_call.arguments),
)
]
}
elif isinstance(event, ResponseCompletedEvent):
if event.response.usage is not None:
chat_log.async_trace(
{
"stats": {
"input_tokens": event.response.usage.input_tokens,
"output_tokens": event.response.usage.output_tokens,
}
}
)
elif isinstance(event, ResponseIncompleteEvent):
if event.response.usage is not None:
chat_log.async_trace(
{
"stats": {
"input_tokens": event.response.usage.input_tokens,
"output_tokens": event.response.usage.output_tokens,
}
}
)
if (
event.response.incomplete_details
and event.response.incomplete_details.reason
):
reason: str = event.response.incomplete_details.reason
else:
reason = "unknown reason"
if reason == "max_output_tokens":
reason = "max output tokens reached"
elif reason == "content_filter":
reason = "content filter triggered"
raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
elif isinstance(event, ResponseFailedEvent):
if event.response.usage is not None:
chat_log.async_trace(
{
"stats": {
"input_tokens": event.response.usage.input_tokens,
"output_tokens": event.response.usage.output_tokens,
}
}
)
reason = "unknown reason"
if event.response.error is not None:
reason = event.response.error.message
raise HomeAssistantError(f"OpenAI response failed: {reason}")
elif isinstance(event, ResponseErrorEvent):
raise HomeAssistantError(f"OpenAI response error: {event.message}")
class OpenAILLMBaseEntity(Entity):
"""OpenAI conversation agent."""
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the agent."""
self.entry = entry
self.subentry = subentry
self._attr_unique_id = subentry.subentry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
name=entry.title,
manufacturer="OpenAI",
model="ChatGPT",
entry_type=dr.DeviceEntryType.SERVICE,
)
async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
tools: list[ToolParam] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
if options.get(CONF_WEB_SEARCH):
web_search = WebSearchToolParam(
type="web_search_preview",
search_context_size=options.get(
CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
),
)
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
web_search["user_location"] = UserLocation(
type="approximate",
city=options.get(CONF_WEB_SEARCH_CITY, ""),
region=options.get(CONF_WEB_SEARCH_REGION, ""),
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
)
if tools is None:
tools = []
tools.append(web_search)
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [
m
for content in chat_log.content
for m in _convert_content_to_param(content)
]
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
model_args = {
"model": model,
"input": messages,
"max_output_tokens": options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id,
"stream": True,
}
if tools:
model_args["tools"] = tools
if model.startswith("o"):
model_args["reasoning"] = {
"effort": options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
)
}
else:
model_args["store"] = False
try:
result = await client.responses.create(**model_args)
except openai.RateLimitError as err:
LOGGER.error("Rate limited by OpenAI: %s", err)
raise HomeAssistantError("Rate limited or insufficient funds") from err
except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err)
raise HomeAssistantError("Error talking to OpenAI") from err
async for content in chat_log.async_add_delta_content_stream(
self.entity_id, _transform_stream(chat_log, result, messages)
):
if not isinstance(content, conversation.AssistantContent):
messages.extend(_convert_content_to_param(content))
if not chat_log.unresponded_tool_results:
break

View File: homeassistant/components/openai_conversation/strings.json

@@ -64,6 +64,51 @@
      "error": {
        "model_not_supported": "This model is not supported, please select a different model"
      }
+    },
+    "ai_task": {
+      "initiate_flow": {
+        "user": "Add AI task service",
+        "reconfigure": "Reconfigure AI task service"
+      },
+      "entry_type": "AI task service",
+      "step": {
+        "init": {
+          "data": {
+            "name": "[%key:common::config_flow::data::name%]",
+            "recommended": "[%key:component::openai_conversation::config_subentries::conversation::step::init::data::recommended%]"
+          }
+        },
+        "advanced": {
+          "title": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::title%]",
+          "data": {
+            "chat_model": "[%key:common::generic::model%]",
+            "max_tokens": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::max_tokens%]",
+            "temperature": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::temperature%]",
+            "top_p": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::top_p%]"
+          }
+        },
+        "model": {
+          "title": "[%key:component::openai_conversation::config_subentries::conversation::step::model::title%]",
+          "data": {
+            "reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::reasoning_effort%]",
+            "web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::web_search%]",
+            "search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::search_context_size%]",
+            "user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::user_location%]"
+          },
+          "data_description": {
+            "reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::reasoning_effort%]",
+            "web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::web_search%]",
+            "search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::search_context_size%]",
+            "user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::user_location%]"
+          }
+        }
+      },
+      "abort": {
+        "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
+      },
+      "error": {
+        "model_not_supported": "[%key:component::openai_conversation::config_subentries::conversation::error::model_not_supported%]"
+      }
    }
  },
  "selector": {

View File: homeassistant/config_entries.py

@@ -3420,6 +3420,11 @@ class ConfigSubentryFlow(
        """Return config entry id."""
        return self.handler[0]

+    @property
+    def _subentry_type(self) -> str:
+        """Return type of subentry we are editing/creating."""
+        return self.handler[1]
+
    @callback
    def _get_entry(self) -> ConfigEntry:
        """Return the config entry linked to the current context."""

View File: tests/components/openai_conversation/common.py

@@ -0,0 +1,87 @@
"""Common utilities for OpenAI conversation tests."""
from openai.types.responses import (
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
)
def create_message_item(
id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
"""Create a message item."""
if isinstance(text, str):
text = [text]
content = ResponseOutputText(annotations=[], text="", type="output_text")
events = [
ResponseOutputItemAddedEvent(
item=ResponseOutputMessage(
id=id,
content=[],
type="message",
role="assistant",
status="in_progress",
),
output_index=output_index,
type="response.output_item.added",
),
ResponseContentPartAddedEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
type="response.content_part.added",
),
]
content.text = "".join(text)
events.extend(
ResponseTextDeltaEvent(
content_index=0,
delta=delta,
item_id=id,
output_index=output_index,
type="response.output_text.delta",
)
for delta in text
)
events.extend(
[
ResponseTextDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
text="".join(text),
type="response.output_text.done",
),
ResponseContentPartDoneEvent(
content_index=0,
item_id=id,
output_index=output_index,
part=content,
type="response.content_part.done",
),
ResponseOutputItemDoneEvent(
item=ResponseOutputMessage(
id=id,
content=[content],
role="assistant",
status="completed",
type="message",
),
output_index=output_index,
type="response.output_item.done",
),
]
)
return events
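
Together with the mock_create_stream fixture added to conftest.py below, this helper lets a test queue one list of stream events per mocked responses.create() call; a short sketch of the intended pattern:

# Sketch: each element of mock_create_stream.return_value feeds one API call;
# passing a list of strings streams the message as several text deltas.
mock_create_stream.return_value = [
    create_message_item(id="msg_A", text=["Hello ", "world!"], output_index=0),
]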

View File: tests/components/openai_conversation/conftest.py

@@ -1,16 +1,35 @@
"""Tests helpers."""

-from unittest.mock import patch
+from collections.abc import Generator
+from unittest.mock import AsyncMock, patch

+from openai.types import ResponseFormatText
+from openai.types.responses import (
+    Response,
+    ResponseCompletedEvent,
+    ResponseCreatedEvent,
+    ResponseError,
+    ResponseErrorEvent,
+    ResponseFailedEvent,
+    ResponseIncompleteEvent,
+    ResponseInProgressEvent,
+    ResponseOutputItemDoneEvent,
+    ResponseTextConfig,
+)
+from openai.types.responses.response import IncompleteDetails
import pytest

-from homeassistant.components.openai_conversation.const import DEFAULT_CONVERSATION_NAME
+from homeassistant.components.openai_conversation.const import (
+    DEFAULT_AI_TASK_NAME,
+    DEFAULT_CONVERSATION_NAME,
+)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component

from tests.common import MockConfigEntry
+from tests.components.conversation import mock_chat_log  # noqa: F401

@pytest.fixture
@@ -29,7 +48,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
                "subentry_type": "conversation",
                "title": DEFAULT_CONVERSATION_NAME,
                "unique_id": None,
-            }
+            },
+            {
+                "data": {},
+                "subentry_type": "ai_task",
+                "title": DEFAULT_AI_TASK_NAME,
+                "unique_id": None,
+            },
        ],
    )
    entry.add_to_hass(hass)
@@ -65,3 +90,89 @@
async def setup_ha(hass: HomeAssistant) -> None:
    """Set up Home Assistant."""
    assert await async_setup_component(hass, "homeassistant", {})
+
+
+@pytest.fixture
+def mock_create_stream() -> Generator[AsyncMock]:
+    """Mock stream response."""
+
+    async def mock_generator(events, **kwargs):
+        response = Response(
+            id="resp_A",
+            created_at=1700000000,
+            error=None,
+            incomplete_details=None,
+            instructions=kwargs.get("instructions"),
+            metadata=kwargs.get("metadata", {}),
+            model=kwargs.get("model", "gpt-4o-mini"),
+            object="response",
+            output=[],
+            parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
+            temperature=kwargs.get("temperature", 1.0),
+            tool_choice=kwargs.get("tool_choice", "auto"),
+            tools=kwargs.get("tools", []),
+            top_p=kwargs.get("top_p", 1.0),
+            max_output_tokens=kwargs.get("max_output_tokens", 100000),
+            previous_response_id=kwargs.get("previous_response_id"),
+            reasoning=kwargs.get("reasoning"),
+            status="in_progress",
+            text=kwargs.get(
+                "text", ResponseTextConfig(format=ResponseFormatText(type="text"))
+            ),
+            truncation=kwargs.get("truncation", "disabled"),
+            usage=None,
+            user=kwargs.get("user"),
+            store=kwargs.get("store", True),
+        )
+        yield ResponseCreatedEvent(
+            response=response,
+            type="response.created",
+        )
+        yield ResponseInProgressEvent(
+            response=response,
+            type="response.in_progress",
+        )
+        response.status = "completed"
+
+        for value in events:
+            if isinstance(value, ResponseOutputItemDoneEvent):
+                response.output.append(value.item)
+            elif isinstance(value, IncompleteDetails):
+                response.status = "incomplete"
+                response.incomplete_details = value
+                break
+            if isinstance(value, ResponseError):
+                response.status = "failed"
+                response.error = value
+                break
+
+            yield value
+
+            if isinstance(value, ResponseErrorEvent):
+                return
+
+        if response.status == "incomplete":
+            yield ResponseIncompleteEvent(
+                response=response,
+                type="response.incomplete",
+            )
+        elif response.status == "failed":
+            yield ResponseFailedEvent(
+                response=response,
+                type="response.failed",
+            )
+        else:
+            yield ResponseCompletedEvent(
+                response=response,
+                type="response.completed",
+            )
+
+    with patch(
+        "openai.resources.responses.AsyncResponses.create",
+        AsyncMock(),
+    ) as mock_create:
+        mock_create.side_effect = lambda **kwargs: mock_generator(
+            mock_create.return_value.pop(0), **kwargs
+        )
+        yield mock_create

View File: tests/components/openai_conversation/test_ai_task.py

@@ -0,0 +1,33 @@
"""Test AI Task platform of OpenAI Conversation integration."""
from unittest.mock import AsyncMock
import pytest
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant
from .common import create_message_item
from tests.common import MockConfigEntry
@pytest.mark.usefixtures("mock_init_component")
async def test_ai_task_generate_text(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_create_stream: AsyncMock,
) -> None:
"""Test that AI task can generate text."""
entity_id = "ai_task.openai_ai_task"
mock_create_stream.return_value = [
create_message_item(id="msg_A", text="Hi there!", output_index=0)
]
result = await ai_task.async_generate_text(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Test prompt",
)
assert result.text == "Hi there!"

View File: tests/components/openai_conversation/test_config_flow.py

@@ -8,7 +8,6 @@ from openai.types.responses import Response, ResponseOutputMessage, ResponseOutp
import pytest

from homeassistant import config_entries
-from homeassistant.components.openai_conversation.config_flow import RECOMMENDED_OPTIONS
from homeassistant.components.openai_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
@@ -24,9 +23,12 @@ from homeassistant.components.openai_conversation.const import (
    CONF_WEB_SEARCH_REGION,
    CONF_WEB_SEARCH_TIMEZONE,
    CONF_WEB_SEARCH_USER_LOCATION,
+    DEFAULT_AI_TASK_NAME,
    DEFAULT_CONVERSATION_NAME,
    DOMAIN,
+    RECOMMENDED_AI_TASK_OPTIONS,
    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_CONVERSATION_OPTIONS,
    RECOMMENDED_MAX_TOKENS,
    RECOMMENDED_TOP_P,
)
@@ -77,10 +79,16 @@ async def test_form(hass: HomeAssistant) -> None:
    assert result2["subentries"] == [
        {
            "subentry_type": "conversation",
-            "data": RECOMMENDED_OPTIONS,
+            "data": RECOMMENDED_CONVERSATION_OPTIONS,
            "title": DEFAULT_CONVERSATION_NAME,
            "unique_id": None,
-        }
+        },
+        {
+            "subentry_type": "ai_task",
+            "data": RECOMMENDED_AI_TASK_OPTIONS,
+            "title": DEFAULT_AI_TASK_NAME,
+            "unique_id": None,
+        },
    ]
    assert len(mock_setup_entry.mock_calls) == 1
@@ -104,19 +112,56 @@ async def test_creating_conversation_subentry(
    result2 = await hass.config_entries.subentries.async_configure(
        result["flow_id"],
-        {"name": "My Custom Agent", **RECOMMENDED_OPTIONS},
+        {"name": "My Custom Agent", **RECOMMENDED_CONVERSATION_OPTIONS},
    )
    await hass.async_block_till_done()

    assert result2["type"] is FlowResultType.CREATE_ENTRY
    assert result2["title"] == "My Custom Agent"

-    processed_options = RECOMMENDED_OPTIONS.copy()
+    processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
    processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()

    assert result2["data"] == processed_options


+async def test_creating_ai_task_subentry(
+    hass: HomeAssistant, mock_config_entry, mock_init_component
+) -> None:
+    """Test creating an AI task subentry."""
+    with patch("openai.resources.models.AsyncModels.list"):
+        result = await hass.config_entries.subentries.async_init(
+            (mock_config_entry.entry_id, "ai_task"),
+            context={"source": config_entries.SOURCE_USER},
+        )
+
+    assert result["type"] is FlowResultType.FORM
+    assert result["step_id"] == "init"
+    assert not result["errors"]
+
+    old_subentries = set(mock_config_entry.subentries)
+
+    with patch("openai.resources.models.AsyncModels.list"):
+        result2 = await hass.config_entries.subentries.async_configure(
+            result["flow_id"],
+            {"name": "My AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
+        )
+    await hass.async_block_till_done()
+
+    assert result2["type"] is FlowResultType.CREATE_ENTRY
+    assert result2["title"] == "My AI Task"
+    assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
+
+    assert len(mock_config_entry.subentries) == 3
+
+    new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
+    new_subentry = mock_config_entry.subentries[new_subentry_id]
+    assert new_subentry.subentry_type == "ai_task"
+    assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
+    assert new_subentry.title == "My AI Task"
+
+
async def test_subentry_recommended(
    hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
View File: tests/components/openai_conversation/test_conversation.py

@@ -1,35 +1,20 @@
"""Tests for the OpenAI integration."""

-from collections.abc import Generator
from unittest.mock import AsyncMock, patch

import httpx
from openai import AuthenticationError, RateLimitError
-from openai.types import ResponseFormatText
from openai.types.responses import (
-    Response,
-    ResponseCompletedEvent,
-    ResponseContentPartAddedEvent,
-    ResponseContentPartDoneEvent,
-    ResponseCreatedEvent,
    ResponseError,
    ResponseErrorEvent,
-    ResponseFailedEvent,
    ResponseFunctionCallArgumentsDeltaEvent,
    ResponseFunctionCallArgumentsDoneEvent,
    ResponseFunctionToolCall,
    ResponseFunctionWebSearch,
-    ResponseIncompleteEvent,
-    ResponseInProgressEvent,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
-    ResponseOutputMessage,
-    ResponseOutputText,
    ResponseReasoningItem,
    ResponseStreamEvent,
-    ResponseTextConfig,
-    ResponseTextDeltaEvent,
-    ResponseTextDoneEvent,
    ResponseWebSearchCallCompletedEvent,
    ResponseWebSearchCallInProgressEvent,
    ResponseWebSearchCallSearchingEvent,
@@ -54,6 +39,8 @@ from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component

+from .common import create_message_item

from tests.common import MockConfigEntry
from tests.components.conversation import (
    MockChatLog,
@@ -61,92 +48,6 @@ from tests.components.conversation import (
)
-@pytest.fixture
-def mock_create_stream() -> Generator[AsyncMock]:
-    """Mock stream response."""
-
-    async def mock_generator(events, **kwargs):
-        response = Response(
-            id="resp_A",
-            created_at=1700000000,
-            error=None,
-            incomplete_details=None,
-            instructions=kwargs.get("instructions"),
-            metadata=kwargs.get("metadata", {}),
-            model=kwargs.get("model", "gpt-4o-mini"),
-            object="response",
-            output=[],
-            parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
-            temperature=kwargs.get("temperature", 1.0),
-            tool_choice=kwargs.get("tool_choice", "auto"),
-            tools=kwargs.get("tools"),
-            top_p=kwargs.get("top_p", 1.0),
-            max_output_tokens=kwargs.get("max_output_tokens", 100000),
-            previous_response_id=kwargs.get("previous_response_id"),
-            reasoning=kwargs.get("reasoning"),
-            status="in_progress",
-            text=kwargs.get(
-                "text", ResponseTextConfig(format=ResponseFormatText(type="text"))
-            ),
-            truncation=kwargs.get("truncation", "disabled"),
-            usage=None,
-            user=kwargs.get("user"),
-            store=kwargs.get("store", True),
-        )
-        yield ResponseCreatedEvent(
-            response=response,
-            type="response.created",
-        )
-        yield ResponseInProgressEvent(
-            response=response,
-            type="response.in_progress",
-        )
-        response.status = "completed"
-
-        for value in events:
-            if isinstance(value, ResponseOutputItemDoneEvent):
-                response.output.append(value.item)
-            elif isinstance(value, IncompleteDetails):
-                response.status = "incomplete"
-                response.incomplete_details = value
-                break
-            if isinstance(value, ResponseError):
-                response.status = "failed"
-                response.error = value
-                break
-
-            yield value
-
-            if isinstance(value, ResponseErrorEvent):
-                return
-
-        if response.status == "incomplete":
-            yield ResponseIncompleteEvent(
-                response=response,
-                type="response.incomplete",
-            )
-        elif response.status == "failed":
-            yield ResponseFailedEvent(
-                response=response,
-                type="response.failed",
-            )
-        else:
-            yield ResponseCompletedEvent(
-                response=response,
-                type="response.completed",
-            )
-
-    with patch(
-        "openai.resources.responses.AsyncResponses.create",
-        AsyncMock(),
-    ) as mock_create:
-        mock_create.side_effect = lambda **kwargs: mock_generator(
-            mock_create.return_value.pop(0), **kwargs
-        )
-        yield mock_create

async def test_entity(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
@@ -341,80 +242,6 @@ async def test_conversation_agent(
    assert agent.supported_languages == "*"
-def create_message_item(
-    id: str, text: str | list[str], output_index: int
-) -> list[ResponseStreamEvent]:
-    """Create a message item."""
-    if isinstance(text, str):
-        text = [text]
-
-    content = ResponseOutputText(annotations=[], text="", type="output_text")
-    events = [
-        ResponseOutputItemAddedEvent(
-            item=ResponseOutputMessage(
-                id=id,
-                content=[],
-                type="message",
-                role="assistant",
-                status="in_progress",
-            ),
-            output_index=output_index,
-            type="response.output_item.added",
-        ),
-        ResponseContentPartAddedEvent(
-            content_index=0,
-            item_id=id,
-            output_index=output_index,
-            part=content,
-            type="response.content_part.added",
-        ),
-    ]
-
-    content.text = "".join(text)
-    events.extend(
-        ResponseTextDeltaEvent(
-            content_index=0,
-            delta=delta,
-            item_id=id,
-            output_index=output_index,
-            type="response.output_text.delta",
-        )
-        for delta in text
-    )
-
-    events.extend(
-        [
-            ResponseTextDoneEvent(
-                content_index=0,
-                item_id=id,
-                output_index=output_index,
-                text="".join(text),
-                type="response.output_text.done",
-            ),
-            ResponseContentPartDoneEvent(
-                content_index=0,
-                item_id=id,
-                output_index=output_index,
-                part=content,
-                type="response.content_part.done",
-            ),
-            ResponseOutputItemDoneEvent(
-                item=ResponseOutputMessage(
-                    id=id,
-                    content=[content],
-                    role="assistant",
-                    status="completed",
-                    type="message",
-                ),
-                output_index=output_index,
-                type="response.output_item.done",
-            ),
-        ]
-    )
-    return events

def create_function_tool_call_item(
    id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]: