Compare commits

...

3 Commits

Author SHA1 Message Date
Paulus Schoutsen
2029996d19 Add AI Task platform to OpenAI 2025-06-22 04:02:26 +00:00
Paulus Schoutsen
3d5df43ebb Whitespace 2025-06-22 00:28:28 +00:00
Paulus Schoutsen
1d5532d00d Migrate OpenAI to config subentries 2025-06-22 00:26:07 +00:00
14 changed files with 1130 additions and 641 deletions

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import base64
from mimetypes import guess_file_type
from pathlib import Path
from types import MappingProxyType
import openai
from openai.types.images_response import ImagesResponse
@@ -19,7 +20,7 @@ from openai.types.responses import (
)
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_API_KEY, Platform
from homeassistant.core import (
HomeAssistant,
@@ -32,7 +33,11 @@ from homeassistant.exceptions import (
HomeAssistantError,
ServiceValidationError,
)
from homeassistant.helpers import config_validation as cv, selector
from homeassistant.helpers import (
config_validation as cv,
entity_registry as er,
selector,
)
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.typing import ConfigType
@@ -44,8 +49,11 @@ from .const import (
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
@@ -56,7 +64,10 @@ from .const import (
SERVICE_GENERATE_IMAGE = "generate_image"
SERVICE_GENERATE_CONTENT = "generate_content"
PLATFORMS = (Platform.CONVERSATION,)
PLATFORMS = (
Platform.AI_TASK,
Platform.CONVERSATION,
)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
@@ -118,7 +129,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
translation_placeholders={"config_entry": entry_id},
)
model: str = entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
# Get first conversation subentry for options
conversation_subentry = next(
(
sub
for sub in entry.subentries.values()
if sub.subentry_type == "conversation"
),
None,
)
if not conversation_subentry:
raise HomeAssistantError("No conversation configuration found")
model: str = conversation_subentry.data.get(
CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL
)
client: openai.AsyncClient = entry.runtime_data
content: ResponseInputMessageContentListParam = [
@@ -169,11 +194,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
model_args = {
"model": model,
"input": messages,
"max_output_tokens": entry.options.get(
"max_output_tokens": conversation_subentry.data.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
"top_p": entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": entry.options.get(
"top_p": conversation_subentry.data.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": conversation_subentry.data.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
),
"user": call.context.user_id,
@@ -182,7 +207,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if model.startswith("o"):
model_args["reasoning"] = {
"effort": entry.options.get(
"effort": conversation_subentry.data.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
)
}
@@ -269,3 +294,49 @@ async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bo
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload OpenAI.

    Tears down every platform listed in PLATFORMS for this config entry.
    Returns False if any platform refused to unload.
    """
    return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bool:
    """Migrate old entry.

    Version 1 stored conversation options directly in ``entry.options``.
    Version 2 moves them into a ``conversation`` config subentry and adds a
    default ``ai_task`` subentry, then clears the entry-level options.
    """
    if entry.version == 1:
        # Migrate from version 1 to version 2
        # Move conversation-specific options to a subentry
        conversation_subentry = ConfigSubentry(
            data=entry.options,
            subentry_type="conversation",
            title=DEFAULT_CONVERSATION_NAME,
            unique_id=None,
        )
        hass.config_entries.async_add_subentry(
            entry,
            conversation_subentry,
        )
        # Version 1 had no AI task support; seed a subentry with the
        # recommended defaults so the new platform sets up immediately.
        hass.config_entries.async_add_subentry(
            entry,
            ConfigSubentry(
                data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
                subentry_type="ai_task",
                title=DEFAULT_AI_TASK_NAME,
                unique_id=None,
            ),
        )

        # Migrate conversation entity to be linked to subentry
        ent_reg = er.async_get(hass)
        for entity_entry in er.async_entries_for_config_entry(ent_reg, entry.entry_id):
            if entity_entry.domain == Platform.CONVERSATION:
                ent_reg.async_update_entity(
                    entity_entry.entity_id,
                    config_subentry_id=conversation_subentry.subentry_id,
                    new_unique_id=conversation_subentry.subentry_id,
                )
                # Version 1 created a single conversation entity per entry,
                # so stop after the first match.
                break

        # Remove options from the main entry
        hass.config_entries.async_update_entry(
            entry,
            options={},
            version=2,
        )

    return True

View File

@@ -0,0 +1,62 @@
"""AI Task integration for OpenAI Conversation."""
from __future__ import annotations
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenAIConfigEntry
from .const import DEFAULT_AI_TASK_NAME, LOGGER
from .entity import OpenAILLMBaseEntity
ERROR_GETTING_RESPONSE = "Sorry, I had a problem getting a response from OpenAI."
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up one AI Task entity for every ``ai_task`` subentry."""
    ai_task_subentries = (
        subentry
        for subentry in config_entry.subentries.values()
        if subentry.subentry_type == "ai_task"
    )
    for subentry in ai_task_subentries:
        async_add_entities(
            [OpenAILLMTaskEntity(config_entry, subentry)],
            config_subentry_id=subentry.subentry_id,
        )
class OpenAILLMTaskEntity(ai_task.AITaskEntity, OpenAILLMBaseEntity):
    """OpenAI AI Task entity."""

    # Only plain text generation is supported by this entity.
    _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_TEXT

    def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the agent.

        The entity name follows the subentry title, falling back to the
        integration default when the title is empty.
        """
        super().__init__(entry, subentry)
        self._attr_name = subentry.title or DEFAULT_AI_TASK_NAME

    async def _async_generate_text(
        self,
        task: ai_task.GenTextTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenTextTaskResult:
        """Handle a generate text task.

        Runs the shared chat-log handler and returns the final assistant
        message as the task result. Raises HomeAssistantError when the model
        did not finish with an assistant message.
        """
        await self._async_handle_chat_log(chat_log)

        # The handler appends model output to the chat log; the last item
        # must be the assistant's answer for the task to have succeeded.
        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
            LOGGER.error(
                "Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
                chat_log.content[-1],
            )
            raise HomeAssistantError(ERROR_GETTING_RESPONSE)

        return ai_task.GenTextTaskResult(
            conversation_id=chat_log.conversation_id,
            text=chat_log.content[-1].content or "",
        )

View File

@@ -15,15 +15,17 @@ from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
ConfigSubentryFlow,
SubentryFlowResult,
)
from homeassistant.const import (
ATTR_LATITUDE,
ATTR_LONGITUDE,
CONF_API_KEY,
CONF_LLM_HASS_API,
CONF_NAME,
)
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.selector import (
@@ -52,8 +54,12 @@ from .const import (
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
@@ -73,12 +79,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
}
)
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
@@ -94,7 +94,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for OpenAI Conversation."""
VERSION = 1
VERSION = 2
async def async_step_user(
self, user_input: dict[str, Any] | None = None
@@ -120,31 +120,67 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry(
title="ChatGPT",
data=user_input,
options=RECOMMENDED_OPTIONS,
subentries=[
{
"subentry_type": "conversation",
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
},
{
"subentry_type": "ai_task",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
],
)
return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
)
@staticmethod
def async_get_options_flow(
config_entry: ConfigEntry,
) -> OptionsFlow:
"""Create the options flow."""
return OpenAIOptionsFlow(config_entry)
@classmethod
@callback
def async_get_supported_subentry_types(
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {
"conversation": LLMSubentryFlowHandler,
"ai_task": LLMSubentryFlowHandler,
}
class OpenAIOptionsFlow(OptionsFlow):
"""OpenAI config flow options handler."""
class LLMSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow."""
self.options = config_entry.options.copy()
last_rendered_recommended = False
is_new: bool
options: dict[str, Any]
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Add a subentry."""
self.is_new = True
if self._subentry_type == "ai_task":
self.options = RECOMMENDED_AI_TASK_OPTIONS.copy()
else:
self.options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
return await self.async_step_init()
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Handle reconfiguration of a subentry."""
self.is_new = False
self.options = self._get_reconfigure_subentry().data.copy()
return await self.async_step_init()
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
) -> SubentryFlowResult:
"""Manage initial options."""
options = self.options
@@ -160,25 +196,53 @@ class OpenAIOptionsFlow(OptionsFlow):
):
options[CONF_LLM_HASS_API] = [suggested_llm_apis]
step_schema: VolDictType = {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": llm.DEFAULT_INSTRUCTIONS_PROMPT},
): TemplateSelector(),
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
SelectSelectorConfig(options=hass_apis, multiple=True)
),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
}
step_schema: VolDictType = {}
if self.is_new:
if CONF_NAME in options:
default_name = options[CONF_NAME]
elif self._subentry_type == "ai_task":
default_name = DEFAULT_AI_TASK_NAME
else:
default_name = DEFAULT_CONVERSATION_NAME
step_schema[vol.Required(CONF_NAME, default=default_name)] = str
if self._subentry_type == "conversation":
step_schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
)
},
): TemplateSelector(),
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
SelectSelectorConfig(options=hass_apis, multiple=True)
),
}
)
step_schema[
vol.Required(CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False))
] = bool
if user_input is not None:
if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None)
if user_input[CONF_RECOMMENDED]:
return self.async_create_entry(title="", data=user_input)
if self.is_new:
return self.async_create_entry(
title=user_input.pop(CONF_NAME),
data=user_input,
)
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=user_input,
)
options.update(user_input)
if CONF_LLM_HASS_API in options and CONF_LLM_HASS_API not in user_input:
@@ -194,7 +258,7 @@ class OpenAIOptionsFlow(OptionsFlow):
async def async_step_advanced(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
) -> SubentryFlowResult:
"""Manage advanced options."""
options = self.options
errors: dict[str, str] = {}
@@ -236,7 +300,7 @@ class OpenAIOptionsFlow(OptionsFlow):
async def async_step_model(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
) -> SubentryFlowResult:
"""Manage model-specific options."""
options = self.options
errors: dict[str, str] = {}
@@ -303,7 +367,16 @@ class OpenAIOptionsFlow(OptionsFlow):
}
if not step_schema:
return self.async_create_entry(title="", data=options)
if self.is_new:
return self.async_create_entry(
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME),
data=options,
)
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=options,
)
if user_input is not None:
if user_input.get(CONF_WEB_SEARCH):
@@ -316,7 +389,16 @@ class OpenAIOptionsFlow(OptionsFlow):
options.pop(CONF_WEB_SEARCH_TIMEZONE, None)
options.update(user_input)
return self.async_create_entry(title="", data=options)
if self.is_new:
return self.async_create_entry(
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME),
data=options,
)
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=options,
)
return self.async_show_form(
step_id="model",
@@ -332,7 +414,7 @@ class OpenAIOptionsFlow(OptionsFlow):
zone_home = self.hass.states.get(ENTITY_ID_HOME)
if zone_home is not None:
client = openai.AsyncOpenAI(
api_key=self.config_entry.data[CONF_API_KEY],
api_key=self._get_entry().data[CONF_API_KEY],
http_client=get_async_client(self.hass),
)
location_schema = vol.Schema(

View File

@@ -2,9 +2,15 @@
import logging
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm
DOMAIN = "openai_conversation"
LOGGER: logging.Logger = logging.getLogger(__package__)
DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
CONF_CHAT_MODEL = "chat_model"
CONF_FILENAMES = "filenames"
CONF_MAX_TOKENS = "max_tokens"
@@ -30,6 +36,16 @@ RECOMMENDED_WEB_SEARCH = False
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium"
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False
RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
RECOMMENDED_AI_TASK_OPTIONS = {
CONF_RECOMMENDED: True,
}
UNSUPPORTED_MODELS: list[str] = [
"o1-mini",
"o1-mini-2024-09-12",

View File

@@ -1,73 +1,17 @@
"""Conversation support for OpenAI."""
from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal, cast
import openai
from openai._streaming import AsyncStream
from openai.types.responses import (
EasyInputMessageParam,
FunctionToolParam,
ResponseCompletedEvent,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionToolCallParam,
ResponseIncompleteEvent,
ResponseInputParam,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputMessageParam,
ResponseReasoningItem,
ResponseReasoningItemParam,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ToolParam,
WebSearchToolParam,
)
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.web_search_tool_param import UserLocation
from voluptuous_openapi import convert
from typing import Literal
from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers import intent
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from . import OpenAIConfigEntry
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
)
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
from .const import CONF_PROMPT, DEFAULT_CONVERSATION_NAME, DOMAIN
from .entity import OpenAILLMBaseEntity
async def async_setup_entry(
@@ -76,175 +20,30 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up conversation entities."""
agent = OpenAIConversationEntity(config_entry)
async_add_entities([agent])
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "conversation":
continue
def _format_tool(
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> FunctionToolParam:
"""Format tool specification."""
return FunctionToolParam(
type="function",
name=tool.name,
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
description=tool.description,
strict=False,
)
def _convert_content_to_param(
content: conversation.Content,
) -> ResponseInputParam:
"""Convert any native chat message for this agent to the native format."""
messages: ResponseInputParam = []
if isinstance(content, conversation.ToolResultContent):
return [
FunctionCallOutput(
type="function_call_output",
call_id=content.tool_call_id,
output=json.dumps(content.tool_result),
)
]
if content.content:
role: Literal["user", "assistant", "system", "developer"] = content.role
if role == "system":
role = "developer"
messages.append(
EasyInputMessageParam(type="message", role=role, content=content.content)
async_add_entities(
[OpenAIConversationEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
messages.extend(
ResponseFunctionToolCallParam(
type="function_call",
name=tool_call.tool_name,
arguments=json.dumps(tool_call.tool_args),
call_id=tool_call.id,
)
for tool_call in content.tool_calls
)
return messages
async def _transform_stream(
chat_log: conversation.ChatLog,
result: AsyncStream[ResponseStreamEvent],
messages: ResponseInputParam,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
"""Transform an OpenAI delta stream into HA format."""
async for event in result:
LOGGER.debug("Received event: %s", event)
if isinstance(event, ResponseOutputItemAddedEvent):
if isinstance(event.item, ResponseOutputMessage):
yield {"role": event.item.role}
elif isinstance(event.item, ResponseFunctionToolCall):
# OpenAI has tool calls as individual events
# while HA puts tool calls inside the assistant message.
# We turn them into individual assistant content for HA
# to ensure that tools are called as soon as possible.
yield {"role": "assistant"}
current_tool_call = event.item
elif isinstance(event, ResponseOutputItemDoneEvent):
item = event.item.model_dump()
item.pop("status", None)
if isinstance(event.item, ResponseReasoningItem):
messages.append(cast(ResponseReasoningItemParam, item))
elif isinstance(event.item, ResponseOutputMessage):
messages.append(cast(ResponseOutputMessageParam, item))
elif isinstance(event.item, ResponseFunctionToolCall):
messages.append(cast(ResponseFunctionToolCallParam, item))
elif isinstance(event, ResponseTextDeltaEvent):
yield {"content": event.delta}
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
current_tool_call.arguments += event.delta
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
current_tool_call.status = "completed"
yield {
"tool_calls": [
llm.ToolInput(
id=current_tool_call.call_id,
tool_name=current_tool_call.name,
tool_args=json.loads(current_tool_call.arguments),
)
]
}
elif isinstance(event, ResponseCompletedEvent):
if event.response.usage is not None:
chat_log.async_trace(
{
"stats": {
"input_tokens": event.response.usage.input_tokens,
"output_tokens": event.response.usage.output_tokens,
}
}
)
elif isinstance(event, ResponseIncompleteEvent):
if event.response.usage is not None:
chat_log.async_trace(
{
"stats": {
"input_tokens": event.response.usage.input_tokens,
"output_tokens": event.response.usage.output_tokens,
}
}
)
if (
event.response.incomplete_details
and event.response.incomplete_details.reason
):
reason: str = event.response.incomplete_details.reason
else:
reason = "unknown reason"
if reason == "max_output_tokens":
reason = "max output tokens reached"
elif reason == "content_filter":
reason = "content filter triggered"
raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
elif isinstance(event, ResponseFailedEvent):
if event.response.usage is not None:
chat_log.async_trace(
{
"stats": {
"input_tokens": event.response.usage.input_tokens,
"output_tokens": event.response.usage.output_tokens,
}
}
)
reason = "unknown reason"
if event.response.error is not None:
reason = event.response.error.message
raise HomeAssistantError(f"OpenAI response failed: {reason}")
elif isinstance(event, ResponseErrorEvent):
raise HomeAssistantError(f"OpenAI response error: {event.message}")
class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
conversation.ConversationEntity,
conversation.AbstractConversationAgent,
OpenAILLMBaseEntity,
):
"""OpenAI conversation agent."""
_attr_has_entity_name = True
_attr_name = None
_attr_supports_streaming = True
def __init__(self, entry: OpenAIConfigEntry) -> None:
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the agent."""
self.entry = entry
self._attr_unique_id = entry.entry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)},
name=entry.title,
manufacturer="OpenAI",
model="ChatGPT",
entry_type=dr.DeviceEntryType.SERVICE,
)
if self.entry.options.get(CONF_LLM_HASS_API):
super().__init__(entry, subentry)
self._attr_name = subentry.title or DEFAULT_CONVERSATION_NAME
if self.subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
)
@@ -276,7 +75,7 @@ class OpenAIConversationEntity(
chat_log: conversation.ChatLog,
) -> conversation.ConversationResult:
"""Process the user input and call the API."""
options = self.entry.options
options = self.subentry.data
try:
await chat_log.async_provide_llm_data(
@@ -299,91 +98,6 @@ class OpenAIConversationEntity(
continue_conversation=chat_log.continue_conversation,
)
async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
) -> None:
"""Generate an answer for the chat log."""
options = self.entry.options
tools: list[ToolParam] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
if options.get(CONF_WEB_SEARCH):
web_search = WebSearchToolParam(
type="web_search_preview",
search_context_size=options.get(
CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
),
)
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
web_search["user_location"] = UserLocation(
type="approximate",
city=options.get(CONF_WEB_SEARCH_CITY, ""),
region=options.get(CONF_WEB_SEARCH_REGION, ""),
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
)
if tools is None:
tools = []
tools.append(web_search)
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [
m
for content in chat_log.content
for m in _convert_content_to_param(content)
]
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
model_args = {
"model": model,
"input": messages,
"max_output_tokens": options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id,
"stream": True,
}
if tools:
model_args["tools"] = tools
if model.startswith("o"):
model_args["reasoning"] = {
"effort": options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
)
}
else:
model_args["store"] = False
try:
result = await client.responses.create(**model_args)
except openai.RateLimitError as err:
LOGGER.error("Rate limited by OpenAI: %s", err)
raise HomeAssistantError("Rate limited or insufficient funds") from err
except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err)
raise HomeAssistantError("Error talking to OpenAI") from err
async for content in chat_log.async_add_delta_content_stream(
self.entity_id, _transform_stream(chat_log, result, messages)
):
if not isinstance(content, conversation.AssistantContent):
messages.extend(_convert_content_to_param(content))
if not chat_log.unresponded_tool_results:
break
async def _async_entry_update_listener(
self, hass: HomeAssistant, entry: ConfigEntry
) -> None:

View File

@@ -0,0 +1,313 @@
"""Base class for OpenAI Conversation entities."""
from collections.abc import AsyncGenerator, Callable
import json
from typing import Any, Literal, cast
import openai
from openai._streaming import AsyncStream
from openai.types.responses import (
EasyInputMessageParam,
FunctionToolParam,
ResponseCompletedEvent,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionToolCallParam,
ResponseIncompleteEvent,
ResponseInputParam,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputMessageParam,
ResponseReasoningItem,
ResponseReasoningItemParam,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ToolParam,
WebSearchToolParam,
)
from openai.types.responses.response_input_param import FunctionCallOutput
from openai.types.responses.web_search_tool_param import UserLocation
from voluptuous_openapi import convert
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
from . import OpenAIConfigEntry
from .const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_REASONING_EFFORT,
CONF_TEMPERATURE,
CONF_TOP_P,
CONF_WEB_SEARCH,
CONF_WEB_SEARCH_CITY,
CONF_WEB_SEARCH_CONTEXT_SIZE,
CONF_WEB_SEARCH_COUNTRY,
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_REASONING_EFFORT,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_P,
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
)
# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> FunctionToolParam:
    """Describe a Home Assistant LLM tool as an OpenAI function tool."""
    # Translate the voluptuous schema into an OpenAPI-style parameter spec.
    parameter_spec = convert(tool.parameters, custom_serializer=custom_serializer)
    return FunctionToolParam(
        type="function",
        name=tool.name,
        description=tool.description,
        parameters=parameter_spec,
        strict=False,
    )
def _convert_content_to_param(
    content: conversation.Content,
) -> ResponseInputParam:
    """Convert a Home Assistant chat message into OpenAI input items."""
    # A tool result maps to exactly one function_call_output item.
    if isinstance(content, conversation.ToolResultContent):
        return [
            FunctionCallOutput(
                type="function_call_output",
                call_id=content.tool_call_id,
                output=json.dumps(content.tool_result),
            )
        ]

    converted: ResponseInputParam = []
    if content.content:
        # The Responses API calls the system role "developer".
        role = "developer" if content.role == "system" else content.role
        converted.append(
            EasyInputMessageParam(type="message", role=role, content=content.content)
        )
    if isinstance(content, conversation.AssistantContent) and content.tool_calls:
        for tool_call in content.tool_calls:
            converted.append(
                ResponseFunctionToolCallParam(
                    type="function_call",
                    name=tool_call.tool_name,
                    arguments=json.dumps(tool_call.tool_args),
                    call_id=tool_call.id,
                )
            )
    return converted
async def _transform_stream(
    chat_log: conversation.ChatLog,
    result: AsyncStream[ResponseStreamEvent],
    messages: ResponseInputParam,
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform an OpenAI delta stream into HA format.

    Yields assistant-content delta dicts for the chat log while also
    appending the completed output items to *messages* so they can be sent
    back as conversation history on the next API call. Raises
    HomeAssistantError on incomplete, failed, or error events.
    """
    async for event in result:
        LOGGER.debug("Received event: %s", event)

        if isinstance(event, ResponseOutputItemAddedEvent):
            if isinstance(event.item, ResponseOutputMessage):
                yield {"role": event.item.role}
            elif isinstance(event.item, ResponseFunctionToolCall):
                # OpenAI has tool calls as individual events
                # while HA puts tool calls inside the assistant message.
                # We turn them into individual assistant content for HA
                # to ensure that tools are called as soon as possible.
                yield {"role": "assistant"}
                # Track the in-progress tool call; its arguments arrive as
                # later delta events and are accumulated below.
                current_tool_call = event.item
        elif isinstance(event, ResponseOutputItemDoneEvent):
            item = event.item.model_dump()
            # "status" is server-side bookkeeping and must not be echoed back.
            item.pop("status", None)
            if isinstance(event.item, ResponseReasoningItem):
                messages.append(cast(ResponseReasoningItemParam, item))
            elif isinstance(event.item, ResponseOutputMessage):
                messages.append(cast(ResponseOutputMessageParam, item))
            elif isinstance(event.item, ResponseFunctionToolCall):
                messages.append(cast(ResponseFunctionToolCallParam, item))
        elif isinstance(event, ResponseTextDeltaEvent):
            yield {"content": event.delta}
        elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
            # Arguments stream in as JSON fragments; concatenate until done.
            current_tool_call.arguments += event.delta
        elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
            current_tool_call.status = "completed"
            yield {
                "tool_calls": [
                    llm.ToolInput(
                        id=current_tool_call.call_id,
                        tool_name=current_tool_call.name,
                        tool_args=json.loads(current_tool_call.arguments),
                    )
                ]
            }
        elif isinstance(event, ResponseCompletedEvent):
            # Record token usage stats in the chat log trace when provided.
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )
        elif isinstance(event, ResponseIncompleteEvent):
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )

            if (
                event.response.incomplete_details
                and event.response.incomplete_details.reason
            ):
                reason: str = event.response.incomplete_details.reason
            else:
                reason = "unknown reason"

            # Map the API's reason codes to human-readable phrases.
            if reason == "max_output_tokens":
                reason = "max output tokens reached"
            elif reason == "content_filter":
                reason = "content filter triggered"

            raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
        elif isinstance(event, ResponseFailedEvent):
            if event.response.usage is not None:
                chat_log.async_trace(
                    {
                        "stats": {
                            "input_tokens": event.response.usage.input_tokens,
                            "output_tokens": event.response.usage.output_tokens,
                        }
                    }
                )
            reason = "unknown reason"
            if event.response.error is not None:
                reason = event.response.error.message
            raise HomeAssistantError(f"OpenAI response failed: {reason}")
        elif isinstance(event, ResponseErrorEvent):
            raise HomeAssistantError(f"OpenAI response error: {event.message}")
class OpenAILLMBaseEntity(Entity):
"""OpenAI conversation agent."""
    def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the agent.

        Stores the config entry and subentry for later option lookups; the
        subentry id doubles as the entity's unique id. All entities of one
        config entry share a single "ChatGPT" service device.
        """
        self.entry = entry
        self.subentry = subentry
        self._attr_unique_id = subentry.subentry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            name=entry.title,
            manufacturer="OpenAI",
            model="ChatGPT",
            entry_type=dr.DeviceEntryType.SERVICE,
        )
async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
tools: list[ToolParam] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
if options.get(CONF_WEB_SEARCH):
web_search = WebSearchToolParam(
type="web_search_preview",
search_context_size=options.get(
CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
),
)
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
web_search["user_location"] = UserLocation(
type="approximate",
city=options.get(CONF_WEB_SEARCH_CITY, ""),
region=options.get(CONF_WEB_SEARCH_REGION, ""),
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
)
if tools is None:
tools = []
tools.append(web_search)
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
messages = [
m
for content in chat_log.content
for m in _convert_content_to_param(content)
]
client = self.entry.runtime_data
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
model_args = {
"model": model,
"input": messages,
"max_output_tokens": options.get(
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
"user": chat_log.conversation_id,
"stream": True,
}
if tools:
model_args["tools"] = tools
if model.startswith("o"):
model_args["reasoning"] = {
"effort": options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
)
}
else:
model_args["store"] = False
try:
result = await client.responses.create(**model_args)
except openai.RateLimitError as err:
LOGGER.error("Rate limited by OpenAI: %s", err)
raise HomeAssistantError("Rate limited or insufficient funds") from err
except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err)
raise HomeAssistantError("Error talking to OpenAI") from err
async for content in chat_log.async_add_delta_content_stream(
self.entity_id, _transform_stream(chat_log, result, messages)
):
if not isinstance(content, conversation.AssistantContent):
messages.extend(_convert_content_to_param(content))
if not chat_log.unresponded_tool_results:
break

View File

@@ -13,45 +13,102 @@
"unknown": "[%key:common::config_flow::error::unknown%]"
}
},
"options": {
"step": {
"init": {
"data": {
"prompt": "Instructions",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings"
"config_subentries": {
"conversation": {
"initiate_flow": {
"user": "Add conversation agent",
"reconfigure": "Reconfigure conversation agent"
},
"entry_type": "Conversation agent",
"step": {
"init": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"prompt": "Instructions",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"recommended": "Recommended model settings"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template."
}
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template."
"advanced": {
"title": "Advanced settings",
"data": {
"chat_model": "[%key:common::generic::model%]",
"max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature",
"top_p": "Top P"
}
},
"model": {
"title": "Model-specific options",
"data": {
"reasoning_effort": "Reasoning effort",
"web_search": "Enable web search",
"search_context_size": "Search context size",
"user_location": "Include home location"
},
"data_description": {
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt",
"web_search": "Allow the model to search the web for the latest information before generating a response",
"search_context_size": "High level guidance for the amount of context window space to use for the search",
"user_location": "Refine search results based on geography"
}
}
},
"advanced": {
"title": "Advanced settings",
"data": {
"chat_model": "[%key:common::generic::model%]",
"max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature",
"top_p": "Top P"
}
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
},
"model": {
"title": "Model-specific options",
"data": {
"reasoning_effort": "Reasoning effort",
"web_search": "Enable web search",
"search_context_size": "Search context size",
"user_location": "Include home location"
},
"data_description": {
"reasoning_effort": "How many reasoning tokens the model should generate before creating a response to the prompt",
"web_search": "Allow the model to search the web for the latest information before generating a response",
"search_context_size": "High level guidance for the amount of context window space to use for the search",
"user_location": "Refine search results based on geography"
}
"error": {
"model_not_supported": "This model is not supported, please select a different model"
}
},
"error": {
"model_not_supported": "This model is not supported, please select a different model"
"ai_task": {
"initiate_flow": {
"user": "Add AI task service",
"reconfigure": "Reconfigure AI task service"
},
"entry_type": "AI task service",
"step": {
"init": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::openai_conversation::config_subentries::conversation::step::init::data::recommended%]"
}
},
"advanced": {
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::title%]",
"data": {
"chat_model": "[%key:common::generic::model%]",
"max_tokens": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::max_tokens%]",
"temperature": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::temperature%]",
"top_p": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::top_p%]"
}
},
"model": {
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::model::title%]",
"data": {
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::reasoning_effort%]",
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::web_search%]",
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::search_context_size%]",
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::user_location%]"
},
"data_description": {
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::reasoning_effort%]",
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::web_search%]",
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::search_context_size%]",
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::user_location%]"
}
}
},
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
},
"error": {
"model_not_supported": "[%key:component::openai_conversation::config_subentries::conversation::error::model_not_supported%]"
}
}
},
"selector": {

View File

@@ -3420,6 +3420,11 @@ class ConfigSubentryFlow(
"""Return config entry id."""
return self.handler[0]
@property
def _subentry_type(self) -> str:
"""Return type of subentry we are editing/creating."""
return self.handler[1]
@callback
def _get_entry(self) -> ConfigEntry:
"""Return the config entry linked to the current context."""

View File

@@ -0,0 +1,87 @@
"""Common utilities for OpenAI conversation tests."""
from openai.types.responses import (
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseStreamEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
)
def create_message_item(
    id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
    """Build the full stream-event sequence for one assistant text message.

    A string is emitted as a single delta; a list of strings is emitted as
    one delta event per element.
    """
    deltas = [text] if isinstance(text, str) else text
    content = ResponseOutputText(annotations=[], text="", type="output_text")
    # Item/part "added" events come first, announcing an in-progress message.
    events: list[ResponseStreamEvent] = [
        ResponseOutputItemAddedEvent(
            item=ResponseOutputMessage(
                id=id,
                content=[],
                type="message",
                role="assistant",
                status="in_progress",
            ),
            output_index=output_index,
            type="response.output_item.added",
        ),
        ResponseContentPartAddedEvent(
            content_index=0,
            item_id=id,
            output_index=output_index,
            part=content,
            type="response.content_part.added",
        ),
    ]
    # Fill in the final text before emitting the delta/done events.
    content.text = "".join(deltas)
    events += [
        ResponseTextDeltaEvent(
            content_index=0,
            delta=delta,
            item_id=id,
            output_index=output_index,
            type="response.output_text.delta",
        )
        for delta in deltas
    ]
    # Close out the text, the content part, and finally the message item.
    events += [
        ResponseTextDoneEvent(
            content_index=0,
            item_id=id,
            output_index=output_index,
            text="".join(deltas),
            type="response.output_text.done",
        ),
        ResponseContentPartDoneEvent(
            content_index=0,
            item_id=id,
            output_index=output_index,
            part=content,
            type="response.content_part.done",
        ),
        ResponseOutputItemDoneEvent(
            item=ResponseOutputMessage(
                id=id,
                content=[content],
                role="assistant",
                status="completed",
                type="message",
            ),
            output_index=output_index,
            type="response.output_item.done",
        ),
    ]
    return events

View File

@@ -1,15 +1,35 @@
"""Tests helpers."""
from unittest.mock import patch
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
from openai.types import ResponseFormatText
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseCreatedEvent,
ResponseError,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemDoneEvent,
ResponseTextConfig,
)
from openai.types.responses.response import IncompleteDetails
import pytest
from homeassistant.components.openai_conversation.const import (
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
)
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
from tests.components.conversation import mock_chat_log # noqa: F401
@pytest.fixture
@@ -21,6 +41,21 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
data={
"api_key": "bla",
},
version=2,
subentries_data=[
{
"data": {},
"subentry_type": "conversation",
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
},
{
"data": {},
"subentry_type": "ai_task",
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
],
)
entry.add_to_hass(hass)
return entry
@@ -31,8 +66,10 @@ def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry:
"""Mock a config entry with assist."""
hass.config_entries.async_update_entry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}
hass.config_entries.async_update_subentry(
mock_config_entry,
next(iter(mock_config_entry.subentries.values())),
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
)
return mock_config_entry
@@ -53,3 +90,89 @@ async def mock_init_component(
async def setup_ha(hass: HomeAssistant) -> None:
    """Set up Home Assistant."""
    setup_ok = await async_setup_component(hass, "homeassistant", {})
    assert setup_ok
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
    """Mock the OpenAI Responses API streaming endpoint.

    Patches ``AsyncResponses.create`` so each call pops the next list of
    pre-built events from ``mock_create_stream.return_value`` and replays
    it as an async stream, wrapped in the created/in-progress framing and
    the matching terminal (completed/incomplete/failed) event.
    """
    async def mock_generator(events, **kwargs):
        # Build a Response object that echoes back the call's kwargs so
        # tests can inspect what the integration sent.
        response = Response(
            id="resp_A",
            created_at=1700000000,
            error=None,
            incomplete_details=None,
            instructions=kwargs.get("instructions"),
            metadata=kwargs.get("metadata", {}),
            model=kwargs.get("model", "gpt-4o-mini"),
            object="response",
            output=[],
            parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
            temperature=kwargs.get("temperature", 1.0),
            tool_choice=kwargs.get("tool_choice", "auto"),
            tools=kwargs.get("tools", []),
            top_p=kwargs.get("top_p", 1.0),
            max_output_tokens=kwargs.get("max_output_tokens", 100000),
            previous_response_id=kwargs.get("previous_response_id"),
            reasoning=kwargs.get("reasoning"),
            status="in_progress",
            text=kwargs.get(
                "text", ResponseTextConfig(format=ResponseFormatText(type="text"))
            ),
            truncation=kwargs.get("truncation", "disabled"),
            usage=None,
            user=kwargs.get("user"),
            store=kwargs.get("store", True),
        )
        # Every stream starts with created + in_progress framing events.
        yield ResponseCreatedEvent(
            response=response,
            type="response.created",
        )
        yield ResponseInProgressEvent(
            response=response,
            type="response.in_progress",
        )
        response.status = "completed"
        for value in events:
            # Completed output items are accumulated on the response;
            # IncompleteDetails/ResponseError sentinels flip the final
            # status and stop the stream without being yielded themselves.
            if isinstance(value, ResponseOutputItemDoneEvent):
                response.output.append(value.item)
            elif isinstance(value, IncompleteDetails):
                response.status = "incomplete"
                response.incomplete_details = value
                break
            if isinstance(value, ResponseError):
                response.status = "failed"
                response.error = value
                break
            yield value
            # An error event terminates the stream with no terminal event.
            if isinstance(value, ResponseErrorEvent):
                return
        # Emit the terminal event matching the final response status.
        if response.status == "incomplete":
            yield ResponseIncompleteEvent(
                response=response,
                type="response.incomplete",
            )
        elif response.status == "failed":
            yield ResponseFailedEvent(
                response=response,
                type="response.failed",
            )
        else:
            yield ResponseCompletedEvent(
                response=response,
                type="response.completed",
            )
    with patch(
        "openai.resources.responses.AsyncResponses.create",
        AsyncMock(),
    ) as mock_create:
        # Each call consumes the next queued event list, so tests can
        # script multiple sequential API responses.
        mock_create.side_effect = lambda **kwargs: mock_generator(
            mock_create.return_value.pop(0), **kwargs
        )
        yield mock_create

View File

@@ -6,7 +6,7 @@
'role': 'user',
}),
dict({
'agent_id': 'conversation.openai',
'agent_id': 'conversation.openai_conversation',
'content': None,
'role': 'assistant',
'tool_calls': list([
@@ -20,14 +20,14 @@
]),
}),
dict({
'agent_id': 'conversation.openai',
'agent_id': 'conversation.openai_conversation',
'role': 'tool_result',
'tool_call_id': 'call_call_1',
'tool_name': 'test_tool',
'tool_result': 'value1',
}),
dict({
'agent_id': 'conversation.openai',
'agent_id': 'conversation.openai_conversation',
'content': None,
'role': 'assistant',
'tool_calls': list([
@@ -41,14 +41,14 @@
]),
}),
dict({
'agent_id': 'conversation.openai',
'agent_id': 'conversation.openai_conversation',
'role': 'tool_result',
'tool_call_id': 'call_call_2',
'tool_name': 'test_tool',
'tool_result': 'value2',
}),
dict({
'agent_id': 'conversation.openai',
'agent_id': 'conversation.openai_conversation',
'content': 'Cool',
'role': 'assistant',
'tool_calls': None,
@@ -62,7 +62,7 @@
'role': 'user',
}),
dict({
'agent_id': 'conversation.openai',
'agent_id': 'conversation.openai_conversation',
'content': None,
'role': 'assistant',
'tool_calls': list([
@@ -76,14 +76,14 @@
]),
}),
dict({
'agent_id': 'conversation.openai',
'agent_id': 'conversation.openai_conversation',
'role': 'tool_result',
'tool_call_id': 'call_call_1',
'tool_name': 'test_tool',
'tool_result': 'value1',
}),
dict({
'agent_id': 'conversation.openai',
'agent_id': 'conversation.openai_conversation',
'content': 'Cool',
'role': 'assistant',
'tool_calls': None,

View File

@@ -0,0 +1,33 @@
"""Test AI Task platform of OpenAI Conversation integration."""
from unittest.mock import AsyncMock
import pytest
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant
from .common import create_message_item
from tests.common import MockConfigEntry
@pytest.mark.usefixtures("mock_init_component")
async def test_ai_task_generate_text(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_create_stream: AsyncMock,
) -> None:
    """Test that AI task can generate text."""
    # Queue a single streamed assistant message for the mocked API.
    mock_create_stream.return_value = [
        create_message_item(id="msg_A", text="Hi there!", output_index=0)
    ]
    result = await ai_task.async_generate_text(
        hass,
        task_name="Test Task",
        entity_id="ai_task.openai_ai_task",
        instructions="Test prompt",
    )
    assert result.text == "Hi there!"

View File

@@ -8,7 +8,6 @@ from openai.types.responses import Response, ResponseOutputMessage, ResponseOutp
import pytest
from homeassistant import config_entries
from homeassistant.components.openai_conversation.config_flow import RECOMMENDED_OPTIONS
from homeassistant.components.openai_conversation.const import (
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
@@ -24,8 +23,12 @@ from homeassistant.components.openai_conversation.const import (
CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_P,
)
@@ -72,42 +75,128 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["data"] == {
"api_key": "bla",
}
assert result2["options"] == RECOMMENDED_OPTIONS
assert result2["options"] == {}
assert result2["subentries"] == [
{
"subentry_type": "conversation",
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
},
{
"subentry_type": "ai_task",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
]
assert len(mock_setup_entry.mock_calls) == 1
async def test_options_recommended(
async def test_creating_conversation_subentry(
    hass: HomeAssistant,
    mock_init_component,
    mock_config_entry: MockConfigEntry,
) -> None:
    """Test creating a conversation subentry."""
    mock_config_entry.add_to_hass(hass)
    # Start the subentry flow for the conversation subentry type.
    init_result = await hass.config_entries.subentries.async_init(
        (mock_config_entry.entry_id, "conversation"),
        context={"source": config_entries.SOURCE_USER},
    )
    assert init_result["type"] is FlowResultType.FORM
    assert init_result["step_id"] == "init"
    assert not init_result["errors"]
    # Submit the form with a custom name plus the recommended options.
    create_result = await hass.config_entries.subentries.async_configure(
        init_result["flow_id"],
        {"name": "My Custom Agent", **RECOMMENDED_CONVERSATION_OPTIONS},
    )
    await hass.async_block_till_done()
    assert create_result["type"] is FlowResultType.CREATE_ENTRY
    assert create_result["title"] == "My Custom Agent"
    # The flow strips whitespace from the prompt before storing it.
    expected_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
    expected_options[CONF_PROMPT] = expected_options[CONF_PROMPT].strip()
    assert create_result["data"] == expected_options
async def test_creating_ai_task_subentry(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test the options flow with recommended settings."""
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
"""Test creating an AI task subentry."""
with patch("openai.resources.models.AsyncModels.list"):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task"),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "init"
assert not result["errors"]
old_subentries = set(mock_config_entry.subentries)
with patch("openai.resources.models.AsyncModels.list"):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{"name": "My AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "My AI Task"
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
assert len(mock_config_entry.subentries) == 3
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "ai_task"
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
assert new_subentry.title == "My AI Task"
async def test_subentry_recommended(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test the subentry flow with recommended settings."""
subentry = next(iter(mock_config_entry.subentries.values()))
subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_type, subentry.subentry_id
)
options = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
options = await hass.config_entries.subentries.async_configure(
subentry_flow["flow_id"],
{
"prompt": "Speak like a pirate",
"recommended": True,
},
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"]["prompt"] == "Speak like a pirate"
assert options["type"] is FlowResultType.ABORT
assert options["reason"] == "reconfigure_successful"
assert subentry.data["prompt"] == "Speak like a pirate"
async def test_options_unsupported_model(
async def test_subentry_unsupported_model(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test the options form giving error about models not supported."""
options_flow = await hass.config_entries.options.async_init(
mock_config_entry.entry_id
"""Test the subentry form giving error about models not supported."""
subentry = next(iter(mock_config_entry.subentries.values()))
subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_type, subentry.subentry_id
)
assert options_flow["type"] == FlowResultType.FORM
assert options_flow["step_id"] == "init"
assert subentry_flow["type"] == FlowResultType.FORM
assert subentry_flow["step_id"] == "init"
# Configure initial step
options_flow = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
subentry_flow = await hass.config_entries.subentries.async_configure(
subentry_flow["flow_id"],
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
@@ -115,19 +204,19 @@ async def test_options_unsupported_model(
},
)
await hass.async_block_till_done()
assert options_flow["type"] == FlowResultType.FORM
assert options_flow["step_id"] == "advanced"
assert subentry_flow["type"] == FlowResultType.FORM
assert subentry_flow["step_id"] == "advanced"
# Configure advanced step
options_flow = await hass.config_entries.options.async_configure(
options_flow["flow_id"],
subentry_flow = await hass.config_entries.subentries.async_configure(
subentry_flow["flow_id"],
{
CONF_CHAT_MODEL: "o1-mini",
},
)
await hass.async_block_till_done()
assert options_flow["type"] is FlowResultType.FORM
assert options_flow["errors"] == {"chat_model": "model_not_supported"}
assert subentry_flow["type"] is FlowResultType.FORM
assert subentry_flow["errors"] == {"chat_model": "model_not_supported"}
@pytest.mark.parametrize(
@@ -494,7 +583,7 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
),
],
)
async def test_options_switching(
async def test_subentry_switching(
hass: HomeAssistant,
mock_config_entry,
mock_init_component,
@@ -502,16 +591,21 @@ async def test_options_switching(
new_options,
expected_options,
) -> None:
"""Test the options form."""
hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
options = await hass.config_entries.options.async_init(mock_config_entry.entry_id)
assert options["step_id"] == "init"
"""Test the subentry form."""
subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, subentry, data=current_options
)
subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_type, subentry.subentry_id
)
assert subentry_flow["step_id"] == "init"
for step_options in new_options:
assert options["type"] == FlowResultType.FORM
assert subentry_flow["type"] == FlowResultType.FORM
# Test that current options are showed as suggested values:
for key in options["data_schema"].schema:
for key in subentry_flow["data_schema"].schema:
if (
isinstance(key.description, dict)
and "suggested_value" in key.description
@@ -523,38 +617,42 @@ async def test_options_switching(
assert key.description["suggested_value"] == current_option
# Configure current step
options = await hass.config_entries.options.async_configure(
options["flow_id"],
subentry_flow = await hass.config_entries.subentries.async_configure(
subentry_flow["flow_id"],
step_options,
)
await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == expected_options
assert subentry_flow["type"] is FlowResultType.ABORT
assert subentry_flow["reason"] == "reconfigure_successful"
assert subentry.data == expected_options
async def test_options_web_search_user_location(
async def test_subentry_web_search_user_location(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test fetching user location."""
options = await hass.config_entries.options.async_init(mock_config_entry.entry_id)
assert options["type"] == FlowResultType.FORM
assert options["step_id"] == "init"
subentry = next(iter(mock_config_entry.subentries.values()))
subentry_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_type, subentry.subentry_id
)
assert subentry_flow["type"] == FlowResultType.FORM
assert subentry_flow["step_id"] == "init"
# Configure initial step
options = await hass.config_entries.options.async_configure(
options["flow_id"],
subentry_flow = await hass.config_entries.subentries.async_configure(
subentry_flow["flow_id"],
{
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
},
)
assert options["type"] == FlowResultType.FORM
assert options["step_id"] == "advanced"
assert subentry_flow["type"] == FlowResultType.FORM
assert subentry_flow["step_id"] == "advanced"
# Configure advanced step
options = await hass.config_entries.options.async_configure(
options["flow_id"],
subentry_flow = await hass.config_entries.subentries.async_configure(
subentry_flow["flow_id"],
{
CONF_TEMPERATURE: 1.0,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
@@ -563,8 +661,8 @@ async def test_options_web_search_user_location(
},
)
await hass.async_block_till_done()
assert options["type"] == FlowResultType.FORM
assert options["step_id"] == "model"
assert subentry_flow["type"] == FlowResultType.FORM
assert subentry_flow["step_id"] == "model"
hass.config.country = "US"
hass.config.time_zone = "America/Los_Angeles"
@@ -601,8 +699,8 @@ async def test_options_web_search_user_location(
)
# Configure model step
options = await hass.config_entries.options.async_configure(
options["flow_id"],
subentry_flow = await hass.config_entries.subentries.async_configure(
subentry_flow["flow_id"],
{
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "medium",
@@ -614,8 +712,9 @@ async def test_options_web_search_user_location(
mock_create.call_args.kwargs["input"][0]["content"] == "Where are the following"
" coordinates located: (37.7749, -122.4194)?"
)
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == {
assert subentry_flow["type"] is FlowResultType.ABORT
assert subentry_flow["reason"] == "reconfigure_successful"
assert subentry.data == {
CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 1.0,

View File

@@ -1,35 +1,20 @@
"""Tests for the OpenAI integration."""
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
import httpx
from openai import AuthenticationError, RateLimitError
from openai.types import ResponseFormatText
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseError,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionWebSearch,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseStreamEvent,
ResponseTextConfig,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
@@ -54,6 +39,8 @@ from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from .common import create_message_item
from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
@@ -61,112 +48,24 @@ from tests.components.conversation import (
)
@pytest.fixture
def mock_create_stream() -> Generator[AsyncMock]:
"""Mock stream response."""
async def mock_generator(events, **kwargs):
response = Response(
id="resp_A",
created_at=1700000000,
error=None,
incomplete_details=None,
instructions=kwargs.get("instructions"),
metadata=kwargs.get("metadata", {}),
model=kwargs.get("model", "gpt-4o-mini"),
object="response",
output=[],
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
temperature=kwargs.get("temperature", 1.0),
tool_choice=kwargs.get("tool_choice", "auto"),
tools=kwargs.get("tools"),
top_p=kwargs.get("top_p", 1.0),
max_output_tokens=kwargs.get("max_output_tokens", 100000),
previous_response_id=kwargs.get("previous_response_id"),
reasoning=kwargs.get("reasoning"),
status="in_progress",
text=kwargs.get(
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
),
truncation=kwargs.get("truncation", "disabled"),
usage=None,
user=kwargs.get("user"),
store=kwargs.get("store", True),
)
yield ResponseCreatedEvent(
response=response,
type="response.created",
)
yield ResponseInProgressEvent(
response=response,
type="response.in_progress",
)
response.status = "completed"
for value in events:
if isinstance(value, ResponseOutputItemDoneEvent):
response.output.append(value.item)
elif isinstance(value, IncompleteDetails):
response.status = "incomplete"
response.incomplete_details = value
break
if isinstance(value, ResponseError):
response.status = "failed"
response.error = value
break
yield value
if isinstance(value, ResponseErrorEvent):
return
if response.status == "incomplete":
yield ResponseIncompleteEvent(
response=response,
type="response.incomplete",
)
elif response.status == "failed":
yield ResponseFailedEvent(
response=response,
type="response.failed",
)
else:
yield ResponseCompletedEvent(
response=response,
type="response.completed",
)
with patch(
"openai.resources.responses.AsyncResponses.create",
AsyncMock(),
) as mock_create:
mock_create.side_effect = lambda **kwargs: mock_generator(
mock_create.return_value.pop(0), **kwargs
)
yield mock_create
async def test_entity(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
) -> None:
"""Test entity properties."""
state = hass.states.get("conversation.openai")
state = hass.states.get("conversation.openai_conversation")
assert state
assert state.attributes["supported_features"] == 0
hass.config_entries.async_update_entry(
hass.config_entries.async_update_subentry(
mock_config_entry,
options={
**mock_config_entry.options,
CONF_LLM_HASS_API: "assist",
},
next(iter(mock_config_entry.subentries.values())),
data={CONF_LLM_HASS_API: "assist"},
)
await hass.config_entries.async_reload(mock_config_entry.entry_id)
state = hass.states.get("conversation.openai")
state = hass.states.get("conversation.openai_conversation")
assert state
assert (
state.attributes["supported_features"]
@@ -261,7 +160,7 @@ async def test_incomplete_response(
"Please tell me a big story",
"mock-conversation-id",
Context(),
agent_id="conversation.openai",
agent_id="conversation.openai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
@@ -285,7 +184,7 @@ async def test_incomplete_response(
"please tell me a big story",
"mock-conversation-id",
Context(),
agent_id="conversation.openai",
agent_id="conversation.openai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
@@ -324,7 +223,7 @@ async def test_failed_response(
"next natural number please",
"mock-conversation-id",
Context(),
agent_id="conversation.openai",
agent_id="conversation.openai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
@@ -343,80 +242,6 @@ async def test_conversation_agent(
assert agent.supported_languages == "*"
def create_message_item(
    id: str, text: str | list[str], output_index: int
) -> list[ResponseStreamEvent]:
    """Create a message item.

    Builds the full sequence of stream events for one assistant message:
    item added, content part added, one text delta per chunk, then the
    text/part/item "done" events.
    """
    chunks = [text] if isinstance(text, str) else text

    # The same part object is referenced by both the part "added" and
    # part "done" events; its text is filled in only after the "added"
    # event is constructed, matching the original event-building order.
    part = ResponseOutputText(annotations=[], text="", type="output_text")

    stream: list[ResponseStreamEvent] = [
        ResponseOutputItemAddedEvent(
            item=ResponseOutputMessage(
                id=id,
                content=[],
                type="message",
                role="assistant",
                status="in_progress",
            ),
            output_index=output_index,
            type="response.output_item.added",
        ),
        ResponseContentPartAddedEvent(
            content_index=0,
            item_id=id,
            output_index=output_index,
            part=part,
            type="response.content_part.added",
        ),
    ]

    part.text = "".join(chunks)

    # One delta event per chunk, in order.
    for delta in chunks:
        stream.append(
            ResponseTextDeltaEvent(
                content_index=0,
                delta=delta,
                item_id=id,
                output_index=output_index,
                type="response.output_text.delta",
            )
        )

    stream.append(
        ResponseTextDoneEvent(
            content_index=0,
            item_id=id,
            output_index=output_index,
            text=part.text,
            type="response.output_text.done",
        )
    )
    stream.append(
        ResponseContentPartDoneEvent(
            content_index=0,
            item_id=id,
            output_index=output_index,
            part=part,
            type="response.content_part.done",
        )
    )
    stream.append(
        ResponseOutputItemDoneEvent(
            item=ResponseOutputMessage(
                id=id,
                content=[part],
                role="assistant",
                status="completed",
                type="message",
            ),
            output_index=output_index,
            type="response.output_item.done",
        )
    )
    return stream
def create_function_tool_call_item(
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
) -> list[ResponseStreamEvent]:
@@ -583,7 +408,7 @@ async def test_function_call(
"Please call the test function",
mock_chat_log.conversation_id,
Context(),
agent_id="conversation.openai",
agent_id="conversation.openai_conversation",
)
assert mock_create_stream.call_args.kwargs["input"][2] == {
@@ -630,7 +455,7 @@ async def test_function_call_without_reasoning(
"Please call the test function",
mock_chat_log.conversation_id,
Context(),
agent_id="conversation.openai",
agent_id="conversation.openai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
@@ -686,7 +511,7 @@ async def test_function_call_invalid(
"Please call the test function",
"mock-conversation-id",
Context(),
agent_id="conversation.openai",
agent_id="conversation.openai_conversation",
)
@@ -720,7 +545,7 @@ async def test_assist_api_tools_conversion(
]
await conversation.async_converse(
hass, "hello", None, Context(), agent_id="conversation.openai"
hass, "hello", None, Context(), agent_id="conversation.openai_conversation"
)
tools = mock_create_stream.mock_calls[0][2]["tools"]
@@ -735,10 +560,12 @@ async def test_web_search(
mock_chat_log: MockChatLog, # noqa: F811
) -> None:
"""Test web_search_tool."""
hass.config_entries.async_update_entry(
subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry,
options={
**mock_config_entry.options,
subentry,
data={
**subentry.data,
CONF_WEB_SEARCH: True,
CONF_WEB_SEARCH_CONTEXT_SIZE: "low",
CONF_WEB_SEARCH_USER_LOCATION: True,
@@ -764,7 +591,7 @@ async def test_web_search(
"What's on the latest news?",
mock_chat_log.conversation_id,
Context(),
agent_id="conversation.openai",
agent_id="conversation.openai_conversation",
)
assert mock_create_stream.mock_calls[0][2]["tools"] == [