forked from home-assistant/core
Add AI Task platform to OpenAI
This commit is contained in:
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import base64
|
import base64
|
||||||
from mimetypes import guess_file_type
|
from mimetypes import guess_file_type
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import MappingProxyType
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
from openai.types.images_response import ImagesResponse
|
from openai.types.images_response import ImagesResponse
|
||||||
@@ -48,9 +49,11 @@ from .const import (
|
|||||||
CONF_REASONING_EFFORT,
|
CONF_REASONING_EFFORT,
|
||||||
CONF_TEMPERATURE,
|
CONF_TEMPERATURE,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
LOGGER,
|
LOGGER,
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_REASONING_EFFORT,
|
RECOMMENDED_REASONING_EFFORT,
|
||||||
@@ -61,7 +64,10 @@ from .const import (
|
|||||||
SERVICE_GENERATE_IMAGE = "generate_image"
|
SERVICE_GENERATE_IMAGE = "generate_image"
|
||||||
SERVICE_GENERATE_CONTENT = "generate_content"
|
SERVICE_GENERATE_CONTENT = "generate_content"
|
||||||
|
|
||||||
PLATFORMS = (Platform.CONVERSATION,)
|
PLATFORMS = (
|
||||||
|
Platform.AI_TASK,
|
||||||
|
Platform.CONVERSATION,
|
||||||
|
)
|
||||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||||
|
|
||||||
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
|
type OpenAIConfigEntry = ConfigEntry[openai.AsyncClient]
|
||||||
@@ -295,7 +301,7 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
|
|||||||
if entry.version == 1:
|
if entry.version == 1:
|
||||||
# Migrate from version 1 to version 2
|
# Migrate from version 1 to version 2
|
||||||
# Move conversation-specific options to a subentry
|
# Move conversation-specific options to a subentry
|
||||||
subentry = ConfigSubentry(
|
conversation_subentry = ConfigSubentry(
|
||||||
data=entry.options,
|
data=entry.options,
|
||||||
subentry_type="conversation",
|
subentry_type="conversation",
|
||||||
title=DEFAULT_CONVERSATION_NAME,
|
title=DEFAULT_CONVERSATION_NAME,
|
||||||
@@ -303,7 +309,16 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
|
|||||||
)
|
)
|
||||||
hass.config_entries.async_add_subentry(
|
hass.config_entries.async_add_subentry(
|
||||||
entry,
|
entry,
|
||||||
subentry,
|
conversation_subentry,
|
||||||
|
)
|
||||||
|
hass.config_entries.async_add_subentry(
|
||||||
|
entry,
|
||||||
|
ConfigSubentry(
|
||||||
|
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
|
||||||
|
subentry_type="ai_task",
|
||||||
|
title=DEFAULT_AI_TASK_NAME,
|
||||||
|
unique_id=None,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Migrate conversation entity to be linked to subentry
|
# Migrate conversation entity to be linked to subentry
|
||||||
@@ -312,8 +327,8 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) ->
|
|||||||
if entity_entry.domain == Platform.CONVERSATION:
|
if entity_entry.domain == Platform.CONVERSATION:
|
||||||
ent_reg.async_update_entity(
|
ent_reg.async_update_entity(
|
||||||
entity_entry.entity_id,
|
entity_entry.entity_id,
|
||||||
config_subentry_id=subentry.subentry_id,
|
config_subentry_id=conversation_subentry.subentry_id,
|
||||||
new_unique_id=subentry.subentry_id,
|
new_unique_id=conversation_subentry.subentry_id,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
62
homeassistant/components/openai_conversation/ai_task.py
Normal file
62
homeassistant/components/openai_conversation/ai_task.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""AI Task integration for OpenAI Conversation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from homeassistant.components import ai_task, conversation
|
||||||
|
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
|
from . import OpenAIConfigEntry
|
||||||
|
from .const import DEFAULT_AI_TASK_NAME, LOGGER
|
||||||
|
from .entity import OpenAILLMBaseEntity
|
||||||
|
|
||||||
|
ERROR_GETTING_RESPONSE = "Sorry, I had a problem getting a response from OpenAI."
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up AI Task entities."""
|
||||||
|
for subentry in config_entry.subentries.values():
|
||||||
|
if subentry.subentry_type != "ai_task":
|
||||||
|
continue
|
||||||
|
|
||||||
|
async_add_entities(
|
||||||
|
[OpenAILLMTaskEntity(config_entry, subentry)],
|
||||||
|
config_subentry_id=subentry.subentry_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAILLMTaskEntity(ai_task.AITaskEntity, OpenAILLMBaseEntity):
|
||||||
|
"""OpenAI AI Task entity."""
|
||||||
|
|
||||||
|
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_TEXT
|
||||||
|
|
||||||
|
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
|
||||||
|
"""Initialize the agent."""
|
||||||
|
super().__init__(entry, subentry)
|
||||||
|
self._attr_name = subentry.title or DEFAULT_AI_TASK_NAME
|
||||||
|
|
||||||
|
async def _async_generate_text(
|
||||||
|
self,
|
||||||
|
task: ai_task.GenTextTask,
|
||||||
|
chat_log: conversation.ChatLog,
|
||||||
|
) -> ai_task.GenTextTaskResult:
|
||||||
|
"""Handle a generate text task."""
|
||||||
|
await self._async_handle_chat_log(chat_log)
|
||||||
|
|
||||||
|
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||||
|
LOGGER.error(
|
||||||
|
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
|
||||||
|
chat_log.content[-1],
|
||||||
|
)
|
||||||
|
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
|
||||||
|
|
||||||
|
return ai_task.GenTextTaskResult(
|
||||||
|
conversation_id=chat_log.conversation_id,
|
||||||
|
text=chat_log.content[-1].content or "",
|
||||||
|
)
|
@@ -54,9 +54,12 @@ from .const import (
|
|||||||
CONF_WEB_SEARCH_REGION,
|
CONF_WEB_SEARCH_REGION,
|
||||||
CONF_WEB_SEARCH_TIMEZONE,
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
CONF_WEB_SEARCH_USER_LOCATION,
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_REASONING_EFFORT,
|
RECOMMENDED_REASONING_EFFORT,
|
||||||
RECOMMENDED_TEMPERATURE,
|
RECOMMENDED_TEMPERATURE,
|
||||||
@@ -76,12 +79,6 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
RECOMMENDED_OPTIONS = {
|
|
||||||
CONF_RECOMMENDED: True,
|
|
||||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
|
||||||
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
||||||
"""Validate the user input allows us to connect.
|
"""Validate the user input allows us to connect.
|
||||||
@@ -126,10 +123,16 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
subentries=[
|
subentries=[
|
||||||
{
|
{
|
||||||
"subentry_type": "conversation",
|
"subentry_type": "conversation",
|
||||||
"data": RECOMMENDED_OPTIONS,
|
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
"title": DEFAULT_CONVERSATION_NAME,
|
"title": DEFAULT_CONVERSATION_NAME,
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"subentry_type": "ai_task",
|
||||||
|
"data": RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
|
"title": DEFAULT_AI_TASK_NAME,
|
||||||
|
"unique_id": None,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -143,10 +146,13 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
cls, config_entry: ConfigEntry
|
cls, config_entry: ConfigEntry
|
||||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||||
"""Return subentries supported by this integration."""
|
"""Return subentries supported by this integration."""
|
||||||
return {"conversation": ConversationSubentryFlowHandler}
|
return {
|
||||||
|
"conversation": LLMSubentryFlowHandler,
|
||||||
|
"ai_task": LLMSubentryFlowHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
class LLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||||
"""Flow for managing conversation subentries."""
|
"""Flow for managing conversation subentries."""
|
||||||
|
|
||||||
last_rendered_recommended = False
|
last_rendered_recommended = False
|
||||||
@@ -158,7 +164,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
) -> SubentryFlowResult:
|
) -> SubentryFlowResult:
|
||||||
"""Add a subentry."""
|
"""Add a subentry."""
|
||||||
self.is_new = True
|
self.is_new = True
|
||||||
self.options = RECOMMENDED_OPTIONS.copy()
|
if self._subentry_type == "ai_task":
|
||||||
|
self.options = RECOMMENDED_AI_TASK_OPTIONS.copy()
|
||||||
|
else:
|
||||||
|
self.options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||||
return await self.async_step_init()
|
return await self.async_step_init()
|
||||||
|
|
||||||
async def async_step_reconfigure(
|
async def async_step_reconfigure(
|
||||||
@@ -190,28 +199,34 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
step_schema: VolDictType = {}
|
step_schema: VolDictType = {}
|
||||||
|
|
||||||
if self.is_new:
|
if self.is_new:
|
||||||
step_schema[vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME)] = (
|
if CONF_NAME in options:
|
||||||
str
|
default_name = options[CONF_NAME]
|
||||||
|
elif self._subentry_type == "ai_task":
|
||||||
|
default_name = DEFAULT_AI_TASK_NAME
|
||||||
|
else:
|
||||||
|
default_name = DEFAULT_CONVERSATION_NAME
|
||||||
|
step_schema[vol.Required(CONF_NAME, default=default_name)] = str
|
||||||
|
|
||||||
|
if self._subentry_type == "conversation":
|
||||||
|
step_schema.update(
|
||||||
|
{
|
||||||
|
vol.Optional(
|
||||||
|
CONF_PROMPT,
|
||||||
|
description={
|
||||||
|
"suggested_value": options.get(
|
||||||
|
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||||
|
)
|
||||||
|
},
|
||||||
|
): TemplateSelector(),
|
||||||
|
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
|
||||||
|
SelectSelectorConfig(options=hass_apis, multiple=True)
|
||||||
|
),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
step_schema.update(
|
step_schema[
|
||||||
{
|
vol.Required(CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False))
|
||||||
vol.Optional(
|
] = bool
|
||||||
CONF_PROMPT,
|
|
||||||
description={
|
|
||||||
"suggested_value": options.get(
|
|
||||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
|
||||||
)
|
|
||||||
},
|
|
||||||
): TemplateSelector(),
|
|
||||||
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
|
|
||||||
SelectSelectorConfig(options=hass_apis, multiple=True)
|
|
||||||
),
|
|
||||||
vol.Required(
|
|
||||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
|
||||||
): bool,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
if not user_input.get(CONF_LLM_HASS_API):
|
if not user_input.get(CONF_LLM_HASS_API):
|
||||||
|
@@ -2,10 +2,14 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
|
from homeassistant.helpers import llm
|
||||||
|
|
||||||
DOMAIN = "openai_conversation"
|
DOMAIN = "openai_conversation"
|
||||||
LOGGER: logging.Logger = logging.getLogger(__package__)
|
LOGGER: logging.Logger = logging.getLogger(__package__)
|
||||||
|
|
||||||
DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
|
DEFAULT_CONVERSATION_NAME = "OpenAI Conversation"
|
||||||
|
DEFAULT_AI_TASK_NAME = "OpenAI AI Task"
|
||||||
|
|
||||||
CONF_CHAT_MODEL = "chat_model"
|
CONF_CHAT_MODEL = "chat_model"
|
||||||
CONF_FILENAMES = "filenames"
|
CONF_FILENAMES = "filenames"
|
||||||
@@ -32,6 +36,16 @@ RECOMMENDED_WEB_SEARCH = False
|
|||||||
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium"
|
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE = "medium"
|
||||||
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False
|
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False
|
||||||
|
|
||||||
|
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||||
|
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||||
|
}
|
||||||
|
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS = {
|
||||||
|
CONF_RECOMMENDED: True,
|
||||||
|
}
|
||||||
|
|
||||||
UNSUPPORTED_MODELS: list[str] = [
|
UNSUPPORTED_MODELS: list[str] = [
|
||||||
"o1-mini",
|
"o1-mini",
|
||||||
"o1-mini-2024-09-12",
|
"o1-mini-2024-09-12",
|
||||||
|
@@ -1,74 +1,17 @@
|
|||||||
"""Conversation support for OpenAI."""
|
"""Conversation support for OpenAI."""
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from typing import Literal
|
||||||
import json
|
|
||||||
from typing import Any, Literal, cast
|
|
||||||
|
|
||||||
import openai
|
|
||||||
from openai._streaming import AsyncStream
|
|
||||||
from openai.types.responses import (
|
|
||||||
EasyInputMessageParam,
|
|
||||||
FunctionToolParam,
|
|
||||||
ResponseCompletedEvent,
|
|
||||||
ResponseErrorEvent,
|
|
||||||
ResponseFailedEvent,
|
|
||||||
ResponseFunctionCallArgumentsDeltaEvent,
|
|
||||||
ResponseFunctionCallArgumentsDoneEvent,
|
|
||||||
ResponseFunctionToolCall,
|
|
||||||
ResponseFunctionToolCallParam,
|
|
||||||
ResponseIncompleteEvent,
|
|
||||||
ResponseInputParam,
|
|
||||||
ResponseOutputItemAddedEvent,
|
|
||||||
ResponseOutputItemDoneEvent,
|
|
||||||
ResponseOutputMessage,
|
|
||||||
ResponseOutputMessageParam,
|
|
||||||
ResponseReasoningItem,
|
|
||||||
ResponseReasoningItemParam,
|
|
||||||
ResponseStreamEvent,
|
|
||||||
ResponseTextDeltaEvent,
|
|
||||||
ToolParam,
|
|
||||||
WebSearchToolParam,
|
|
||||||
)
|
|
||||||
from openai.types.responses.response_input_param import FunctionCallOutput
|
|
||||||
from openai.types.responses.web_search_tool_param import UserLocation
|
|
||||||
from voluptuous_openapi import convert
|
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, conversation
|
from homeassistant.components import assist_pipeline, conversation
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||||
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.helpers import device_registry as dr, intent, llm
|
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
from . import OpenAIConfigEntry
|
from . import OpenAIConfigEntry
|
||||||
from .const import (
|
from .const import CONF_PROMPT, DEFAULT_CONVERSATION_NAME, DOMAIN
|
||||||
CONF_CHAT_MODEL,
|
from .entity import OpenAILLMBaseEntity
|
||||||
CONF_MAX_TOKENS,
|
|
||||||
CONF_PROMPT,
|
|
||||||
CONF_REASONING_EFFORT,
|
|
||||||
CONF_TEMPERATURE,
|
|
||||||
CONF_TOP_P,
|
|
||||||
CONF_WEB_SEARCH,
|
|
||||||
CONF_WEB_SEARCH_CITY,
|
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
|
||||||
CONF_WEB_SEARCH_COUNTRY,
|
|
||||||
CONF_WEB_SEARCH_REGION,
|
|
||||||
CONF_WEB_SEARCH_TIMEZONE,
|
|
||||||
CONF_WEB_SEARCH_USER_LOCATION,
|
|
||||||
DEFAULT_CONVERSATION_NAME,
|
|
||||||
DOMAIN,
|
|
||||||
LOGGER,
|
|
||||||
RECOMMENDED_CHAT_MODEL,
|
|
||||||
RECOMMENDED_MAX_TOKENS,
|
|
||||||
RECOMMENDED_REASONING_EFFORT,
|
|
||||||
RECOMMENDED_TEMPERATURE,
|
|
||||||
RECOMMENDED_TOP_P,
|
|
||||||
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Max number of back and forth with the LLM to generate a response
|
|
||||||
MAX_TOOL_ITERATIONS = 10
|
|
||||||
|
|
||||||
|
|
||||||
async def async_setup_entry(
|
async def async_setup_entry(
|
||||||
@@ -87,152 +30,10 @@ async def async_setup_entry(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _format_tool(
|
|
||||||
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
|
||||||
) -> FunctionToolParam:
|
|
||||||
"""Format tool specification."""
|
|
||||||
return FunctionToolParam(
|
|
||||||
type="function",
|
|
||||||
name=tool.name,
|
|
||||||
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
|
||||||
description=tool.description,
|
|
||||||
strict=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_content_to_param(
|
|
||||||
content: conversation.Content,
|
|
||||||
) -> ResponseInputParam:
|
|
||||||
"""Convert any native chat message for this agent to the native format."""
|
|
||||||
messages: ResponseInputParam = []
|
|
||||||
if isinstance(content, conversation.ToolResultContent):
|
|
||||||
return [
|
|
||||||
FunctionCallOutput(
|
|
||||||
type="function_call_output",
|
|
||||||
call_id=content.tool_call_id,
|
|
||||||
output=json.dumps(content.tool_result),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if content.content:
|
|
||||||
role: Literal["user", "assistant", "system", "developer"] = content.role
|
|
||||||
if role == "system":
|
|
||||||
role = "developer"
|
|
||||||
messages.append(
|
|
||||||
EasyInputMessageParam(type="message", role=role, content=content.content)
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
|
||||||
messages.extend(
|
|
||||||
ResponseFunctionToolCallParam(
|
|
||||||
type="function_call",
|
|
||||||
name=tool_call.tool_name,
|
|
||||||
arguments=json.dumps(tool_call.tool_args),
|
|
||||||
call_id=tool_call.id,
|
|
||||||
)
|
|
||||||
for tool_call in content.tool_calls
|
|
||||||
)
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
async def _transform_stream(
|
|
||||||
chat_log: conversation.ChatLog,
|
|
||||||
result: AsyncStream[ResponseStreamEvent],
|
|
||||||
messages: ResponseInputParam,
|
|
||||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
|
||||||
"""Transform an OpenAI delta stream into HA format."""
|
|
||||||
async for event in result:
|
|
||||||
LOGGER.debug("Received event: %s", event)
|
|
||||||
|
|
||||||
if isinstance(event, ResponseOutputItemAddedEvent):
|
|
||||||
if isinstance(event.item, ResponseOutputMessage):
|
|
||||||
yield {"role": event.item.role}
|
|
||||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
|
||||||
# OpenAI has tool calls as individual events
|
|
||||||
# while HA puts tool calls inside the assistant message.
|
|
||||||
# We turn them into individual assistant content for HA
|
|
||||||
# to ensure that tools are called as soon as possible.
|
|
||||||
yield {"role": "assistant"}
|
|
||||||
current_tool_call = event.item
|
|
||||||
elif isinstance(event, ResponseOutputItemDoneEvent):
|
|
||||||
item = event.item.model_dump()
|
|
||||||
item.pop("status", None)
|
|
||||||
if isinstance(event.item, ResponseReasoningItem):
|
|
||||||
messages.append(cast(ResponseReasoningItemParam, item))
|
|
||||||
elif isinstance(event.item, ResponseOutputMessage):
|
|
||||||
messages.append(cast(ResponseOutputMessageParam, item))
|
|
||||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
|
||||||
messages.append(cast(ResponseFunctionToolCallParam, item))
|
|
||||||
elif isinstance(event, ResponseTextDeltaEvent):
|
|
||||||
yield {"content": event.delta}
|
|
||||||
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
|
|
||||||
current_tool_call.arguments += event.delta
|
|
||||||
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
|
|
||||||
current_tool_call.status = "completed"
|
|
||||||
yield {
|
|
||||||
"tool_calls": [
|
|
||||||
llm.ToolInput(
|
|
||||||
id=current_tool_call.call_id,
|
|
||||||
tool_name=current_tool_call.name,
|
|
||||||
tool_args=json.loads(current_tool_call.arguments),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
}
|
|
||||||
elif isinstance(event, ResponseCompletedEvent):
|
|
||||||
if event.response.usage is not None:
|
|
||||||
chat_log.async_trace(
|
|
||||||
{
|
|
||||||
"stats": {
|
|
||||||
"input_tokens": event.response.usage.input_tokens,
|
|
||||||
"output_tokens": event.response.usage.output_tokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif isinstance(event, ResponseIncompleteEvent):
|
|
||||||
if event.response.usage is not None:
|
|
||||||
chat_log.async_trace(
|
|
||||||
{
|
|
||||||
"stats": {
|
|
||||||
"input_tokens": event.response.usage.input_tokens,
|
|
||||||
"output_tokens": event.response.usage.output_tokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
event.response.incomplete_details
|
|
||||||
and event.response.incomplete_details.reason
|
|
||||||
):
|
|
||||||
reason: str = event.response.incomplete_details.reason
|
|
||||||
else:
|
|
||||||
reason = "unknown reason"
|
|
||||||
|
|
||||||
if reason == "max_output_tokens":
|
|
||||||
reason = "max output tokens reached"
|
|
||||||
elif reason == "content_filter":
|
|
||||||
reason = "content filter triggered"
|
|
||||||
|
|
||||||
raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
|
|
||||||
elif isinstance(event, ResponseFailedEvent):
|
|
||||||
if event.response.usage is not None:
|
|
||||||
chat_log.async_trace(
|
|
||||||
{
|
|
||||||
"stats": {
|
|
||||||
"input_tokens": event.response.usage.input_tokens,
|
|
||||||
"output_tokens": event.response.usage.output_tokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
reason = "unknown reason"
|
|
||||||
if event.response.error is not None:
|
|
||||||
reason = event.response.error.message
|
|
||||||
raise HomeAssistantError(f"OpenAI response failed: {reason}")
|
|
||||||
elif isinstance(event, ResponseErrorEvent):
|
|
||||||
raise HomeAssistantError(f"OpenAI response error: {event.message}")
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIConversationEntity(
|
class OpenAIConversationEntity(
|
||||||
conversation.ConversationEntity, conversation.AbstractConversationAgent
|
conversation.ConversationEntity,
|
||||||
|
conversation.AbstractConversationAgent,
|
||||||
|
OpenAILLMBaseEntity,
|
||||||
):
|
):
|
||||||
"""OpenAI conversation agent."""
|
"""OpenAI conversation agent."""
|
||||||
|
|
||||||
@@ -240,17 +41,8 @@ class OpenAIConversationEntity(
|
|||||||
|
|
||||||
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
|
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
|
||||||
"""Initialize the agent."""
|
"""Initialize the agent."""
|
||||||
self.entry = entry
|
super().__init__(entry, subentry)
|
||||||
self.subentry = subentry
|
|
||||||
self._attr_name = subentry.title or DEFAULT_CONVERSATION_NAME
|
self._attr_name = subentry.title or DEFAULT_CONVERSATION_NAME
|
||||||
self._attr_unique_id = subentry.subentry_id
|
|
||||||
self._attr_device_info = dr.DeviceInfo(
|
|
||||||
identifiers={(DOMAIN, entry.entry_id)},
|
|
||||||
name=entry.title,
|
|
||||||
manufacturer="OpenAI",
|
|
||||||
model="ChatGPT",
|
|
||||||
entry_type=dr.DeviceEntryType.SERVICE,
|
|
||||||
)
|
|
||||||
if self.subentry.data.get(CONF_LLM_HASS_API):
|
if self.subentry.data.get(CONF_LLM_HASS_API):
|
||||||
self._attr_supported_features = (
|
self._attr_supported_features = (
|
||||||
conversation.ConversationEntityFeature.CONTROL
|
conversation.ConversationEntityFeature.CONTROL
|
||||||
@@ -306,91 +98,6 @@ class OpenAIConversationEntity(
|
|||||||
continue_conversation=chat_log.continue_conversation,
|
continue_conversation=chat_log.continue_conversation,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_handle_chat_log(
|
|
||||||
self,
|
|
||||||
chat_log: conversation.ChatLog,
|
|
||||||
) -> None:
|
|
||||||
"""Generate an answer for the chat log."""
|
|
||||||
options = self.subentry.data
|
|
||||||
|
|
||||||
tools: list[ToolParam] | None = None
|
|
||||||
if chat_log.llm_api:
|
|
||||||
tools = [
|
|
||||||
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
|
||||||
for tool in chat_log.llm_api.tools
|
|
||||||
]
|
|
||||||
|
|
||||||
if options.get(CONF_WEB_SEARCH):
|
|
||||||
web_search = WebSearchToolParam(
|
|
||||||
type="web_search_preview",
|
|
||||||
search_context_size=options.get(
|
|
||||||
CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
|
|
||||||
web_search["user_location"] = UserLocation(
|
|
||||||
type="approximate",
|
|
||||||
city=options.get(CONF_WEB_SEARCH_CITY, ""),
|
|
||||||
region=options.get(CONF_WEB_SEARCH_REGION, ""),
|
|
||||||
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
|
|
||||||
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
|
|
||||||
)
|
|
||||||
if tools is None:
|
|
||||||
tools = []
|
|
||||||
tools.append(web_search)
|
|
||||||
|
|
||||||
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
|
||||||
messages = [
|
|
||||||
m
|
|
||||||
for content in chat_log.content
|
|
||||||
for m in _convert_content_to_param(content)
|
|
||||||
]
|
|
||||||
|
|
||||||
client = self.entry.runtime_data
|
|
||||||
|
|
||||||
# To prevent infinite loops, we limit the number of iterations
|
|
||||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
|
||||||
model_args = {
|
|
||||||
"model": model,
|
|
||||||
"input": messages,
|
|
||||||
"max_output_tokens": options.get(
|
|
||||||
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
|
|
||||||
),
|
|
||||||
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
|
||||||
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
|
||||||
"user": chat_log.conversation_id,
|
|
||||||
"stream": True,
|
|
||||||
}
|
|
||||||
if tools:
|
|
||||||
model_args["tools"] = tools
|
|
||||||
|
|
||||||
if model.startswith("o"):
|
|
||||||
model_args["reasoning"] = {
|
|
||||||
"effort": options.get(
|
|
||||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
|
||||||
)
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
model_args["store"] = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await client.responses.create(**model_args)
|
|
||||||
except openai.RateLimitError as err:
|
|
||||||
LOGGER.error("Rate limited by OpenAI: %s", err)
|
|
||||||
raise HomeAssistantError("Rate limited or insufficient funds") from err
|
|
||||||
except openai.OpenAIError as err:
|
|
||||||
LOGGER.error("Error talking to OpenAI: %s", err)
|
|
||||||
raise HomeAssistantError("Error talking to OpenAI") from err
|
|
||||||
|
|
||||||
async for content in chat_log.async_add_delta_content_stream(
|
|
||||||
self.entity_id, _transform_stream(chat_log, result, messages)
|
|
||||||
):
|
|
||||||
if not isinstance(content, conversation.AssistantContent):
|
|
||||||
messages.extend(_convert_content_to_param(content))
|
|
||||||
|
|
||||||
if not chat_log.unresponded_tool_results:
|
|
||||||
break
|
|
||||||
|
|
||||||
async def _async_entry_update_listener(
|
async def _async_entry_update_listener(
|
||||||
self, hass: HomeAssistant, entry: ConfigEntry
|
self, hass: HomeAssistant, entry: ConfigEntry
|
||||||
) -> None:
|
) -> None:
|
||||||
|
313
homeassistant/components/openai_conversation/entity.py
Normal file
313
homeassistant/components/openai_conversation/entity.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
"""Base class for OpenAI Conversation entities."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator, Callable
|
||||||
|
import json
|
||||||
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from openai._streaming import AsyncStream
|
||||||
|
from openai.types.responses import (
|
||||||
|
EasyInputMessageParam,
|
||||||
|
FunctionToolParam,
|
||||||
|
ResponseCompletedEvent,
|
||||||
|
ResponseErrorEvent,
|
||||||
|
ResponseFailedEvent,
|
||||||
|
ResponseFunctionCallArgumentsDeltaEvent,
|
||||||
|
ResponseFunctionCallArgumentsDoneEvent,
|
||||||
|
ResponseFunctionToolCall,
|
||||||
|
ResponseFunctionToolCallParam,
|
||||||
|
ResponseIncompleteEvent,
|
||||||
|
ResponseInputParam,
|
||||||
|
ResponseOutputItemAddedEvent,
|
||||||
|
ResponseOutputItemDoneEvent,
|
||||||
|
ResponseOutputMessage,
|
||||||
|
ResponseOutputMessageParam,
|
||||||
|
ResponseReasoningItem,
|
||||||
|
ResponseReasoningItemParam,
|
||||||
|
ResponseStreamEvent,
|
||||||
|
ResponseTextDeltaEvent,
|
||||||
|
ToolParam,
|
||||||
|
WebSearchToolParam,
|
||||||
|
)
|
||||||
|
from openai.types.responses.response_input_param import FunctionCallOutput
|
||||||
|
from openai.types.responses.web_search_tool_param import UserLocation
|
||||||
|
from voluptuous_openapi import convert
|
||||||
|
|
||||||
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.config_entries import ConfigSubentry
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import device_registry as dr, llm
|
||||||
|
from homeassistant.helpers.entity import Entity
|
||||||
|
|
||||||
|
from . import OpenAIConfigEntry
|
||||||
|
from .const import (
|
||||||
|
CONF_CHAT_MODEL,
|
||||||
|
CONF_MAX_TOKENS,
|
||||||
|
CONF_REASONING_EFFORT,
|
||||||
|
CONF_TEMPERATURE,
|
||||||
|
CONF_TOP_P,
|
||||||
|
CONF_WEB_SEARCH,
|
||||||
|
CONF_WEB_SEARCH_CITY,
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
CONF_WEB_SEARCH_COUNTRY,
|
||||||
|
CONF_WEB_SEARCH_REGION,
|
||||||
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
|
DOMAIN,
|
||||||
|
LOGGER,
|
||||||
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_MAX_TOKENS,
|
||||||
|
RECOMMENDED_REASONING_EFFORT,
|
||||||
|
RECOMMENDED_TEMPERATURE,
|
||||||
|
RECOMMENDED_TOP_P,
|
||||||
|
RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Max number of back and forth with the LLM to generate a response
|
||||||
|
MAX_TOOL_ITERATIONS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def _format_tool(
|
||||||
|
tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
|
||||||
|
) -> FunctionToolParam:
|
||||||
|
"""Format tool specification."""
|
||||||
|
return FunctionToolParam(
|
||||||
|
type="function",
|
||||||
|
name=tool.name,
|
||||||
|
parameters=convert(tool.parameters, custom_serializer=custom_serializer),
|
||||||
|
description=tool.description,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_content_to_param(
|
||||||
|
content: conversation.Content,
|
||||||
|
) -> ResponseInputParam:
|
||||||
|
"""Convert any native chat message for this agent to the native format."""
|
||||||
|
messages: ResponseInputParam = []
|
||||||
|
if isinstance(content, conversation.ToolResultContent):
|
||||||
|
return [
|
||||||
|
FunctionCallOutput(
|
||||||
|
type="function_call_output",
|
||||||
|
call_id=content.tool_call_id,
|
||||||
|
output=json.dumps(content.tool_result),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
if content.content:
|
||||||
|
role: Literal["user", "assistant", "system", "developer"] = content.role
|
||||||
|
if role == "system":
|
||||||
|
role = "developer"
|
||||||
|
messages.append(
|
||||||
|
EasyInputMessageParam(type="message", role=role, content=content.content)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(content, conversation.AssistantContent) and content.tool_calls:
|
||||||
|
messages.extend(
|
||||||
|
ResponseFunctionToolCallParam(
|
||||||
|
type="function_call",
|
||||||
|
name=tool_call.tool_name,
|
||||||
|
arguments=json.dumps(tool_call.tool_args),
|
||||||
|
call_id=tool_call.id,
|
||||||
|
)
|
||||||
|
for tool_call in content.tool_calls
|
||||||
|
)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
async def _transform_stream(
|
||||||
|
chat_log: conversation.ChatLog,
|
||||||
|
result: AsyncStream[ResponseStreamEvent],
|
||||||
|
messages: ResponseInputParam,
|
||||||
|
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||||
|
"""Transform an OpenAI delta stream into HA format."""
|
||||||
|
async for event in result:
|
||||||
|
LOGGER.debug("Received event: %s", event)
|
||||||
|
|
||||||
|
if isinstance(event, ResponseOutputItemAddedEvent):
|
||||||
|
if isinstance(event.item, ResponseOutputMessage):
|
||||||
|
yield {"role": event.item.role}
|
||||||
|
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||||
|
# OpenAI has tool calls as individual events
|
||||||
|
# while HA puts tool calls inside the assistant message.
|
||||||
|
# We turn them into individual assistant content for HA
|
||||||
|
# to ensure that tools are called as soon as possible.
|
||||||
|
yield {"role": "assistant"}
|
||||||
|
current_tool_call = event.item
|
||||||
|
elif isinstance(event, ResponseOutputItemDoneEvent):
|
||||||
|
item = event.item.model_dump()
|
||||||
|
item.pop("status", None)
|
||||||
|
if isinstance(event.item, ResponseReasoningItem):
|
||||||
|
messages.append(cast(ResponseReasoningItemParam, item))
|
||||||
|
elif isinstance(event.item, ResponseOutputMessage):
|
||||||
|
messages.append(cast(ResponseOutputMessageParam, item))
|
||||||
|
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||||
|
messages.append(cast(ResponseFunctionToolCallParam, item))
|
||||||
|
elif isinstance(event, ResponseTextDeltaEvent):
|
||||||
|
yield {"content": event.delta}
|
||||||
|
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
|
||||||
|
current_tool_call.arguments += event.delta
|
||||||
|
elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
|
||||||
|
current_tool_call.status = "completed"
|
||||||
|
yield {
|
||||||
|
"tool_calls": [
|
||||||
|
llm.ToolInput(
|
||||||
|
id=current_tool_call.call_id,
|
||||||
|
tool_name=current_tool_call.name,
|
||||||
|
tool_args=json.loads(current_tool_call.arguments),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
elif isinstance(event, ResponseCompletedEvent):
|
||||||
|
if event.response.usage is not None:
|
||||||
|
chat_log.async_trace(
|
||||||
|
{
|
||||||
|
"stats": {
|
||||||
|
"input_tokens": event.response.usage.input_tokens,
|
||||||
|
"output_tokens": event.response.usage.output_tokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif isinstance(event, ResponseIncompleteEvent):
|
||||||
|
if event.response.usage is not None:
|
||||||
|
chat_log.async_trace(
|
||||||
|
{
|
||||||
|
"stats": {
|
||||||
|
"input_tokens": event.response.usage.input_tokens,
|
||||||
|
"output_tokens": event.response.usage.output_tokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
event.response.incomplete_details
|
||||||
|
and event.response.incomplete_details.reason
|
||||||
|
):
|
||||||
|
reason: str = event.response.incomplete_details.reason
|
||||||
|
else:
|
||||||
|
reason = "unknown reason"
|
||||||
|
|
||||||
|
if reason == "max_output_tokens":
|
||||||
|
reason = "max output tokens reached"
|
||||||
|
elif reason == "content_filter":
|
||||||
|
reason = "content filter triggered"
|
||||||
|
|
||||||
|
raise HomeAssistantError(f"OpenAI response incomplete: {reason}")
|
||||||
|
elif isinstance(event, ResponseFailedEvent):
|
||||||
|
if event.response.usage is not None:
|
||||||
|
chat_log.async_trace(
|
||||||
|
{
|
||||||
|
"stats": {
|
||||||
|
"input_tokens": event.response.usage.input_tokens,
|
||||||
|
"output_tokens": event.response.usage.output_tokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
reason = "unknown reason"
|
||||||
|
if event.response.error is not None:
|
||||||
|
reason = event.response.error.message
|
||||||
|
raise HomeAssistantError(f"OpenAI response failed: {reason}")
|
||||||
|
elif isinstance(event, ResponseErrorEvent):
|
||||||
|
raise HomeAssistantError(f"OpenAI response error: {event.message}")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAILLMBaseEntity(Entity):
|
||||||
|
"""OpenAI conversation agent."""
|
||||||
|
|
||||||
|
def __init__(self, entry: OpenAIConfigEntry, subentry: ConfigSubentry) -> None:
|
||||||
|
"""Initialize the agent."""
|
||||||
|
self.entry = entry
|
||||||
|
self.subentry = subentry
|
||||||
|
self._attr_unique_id = subentry.subentry_id
|
||||||
|
self._attr_device_info = dr.DeviceInfo(
|
||||||
|
identifiers={(DOMAIN, entry.entry_id)},
|
||||||
|
name=entry.title,
|
||||||
|
manufacturer="OpenAI",
|
||||||
|
model="ChatGPT",
|
||||||
|
entry_type=dr.DeviceEntryType.SERVICE,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _async_handle_chat_log(
|
||||||
|
self,
|
||||||
|
chat_log: conversation.ChatLog,
|
||||||
|
) -> None:
|
||||||
|
"""Generate an answer for the chat log."""
|
||||||
|
options = self.subentry.data
|
||||||
|
|
||||||
|
tools: list[ToolParam] | None = None
|
||||||
|
if chat_log.llm_api:
|
||||||
|
tools = [
|
||||||
|
_format_tool(tool, chat_log.llm_api.custom_serializer)
|
||||||
|
for tool in chat_log.llm_api.tools
|
||||||
|
]
|
||||||
|
|
||||||
|
if options.get(CONF_WEB_SEARCH):
|
||||||
|
web_search = WebSearchToolParam(
|
||||||
|
type="web_search_preview",
|
||||||
|
search_context_size=options.get(
|
||||||
|
CONF_WEB_SEARCH_CONTEXT_SIZE, RECOMMENDED_WEB_SEARCH_CONTEXT_SIZE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if options.get(CONF_WEB_SEARCH_USER_LOCATION):
|
||||||
|
web_search["user_location"] = UserLocation(
|
||||||
|
type="approximate",
|
||||||
|
city=options.get(CONF_WEB_SEARCH_CITY, ""),
|
||||||
|
region=options.get(CONF_WEB_SEARCH_REGION, ""),
|
||||||
|
country=options.get(CONF_WEB_SEARCH_COUNTRY, ""),
|
||||||
|
timezone=options.get(CONF_WEB_SEARCH_TIMEZONE, ""),
|
||||||
|
)
|
||||||
|
if tools is None:
|
||||||
|
tools = []
|
||||||
|
tools.append(web_search)
|
||||||
|
|
||||||
|
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||||
|
messages = [
|
||||||
|
m
|
||||||
|
for content in chat_log.content
|
||||||
|
for m in _convert_content_to_param(content)
|
||||||
|
]
|
||||||
|
|
||||||
|
client = self.entry.runtime_data
|
||||||
|
|
||||||
|
# To prevent infinite loops, we limit the number of iterations
|
||||||
|
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||||
|
model_args = {
|
||||||
|
"model": model,
|
||||||
|
"input": messages,
|
||||||
|
"max_output_tokens": options.get(
|
||||||
|
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
|
||||||
|
),
|
||||||
|
"top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||||
|
"temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||||
|
"user": chat_log.conversation_id,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
if tools:
|
||||||
|
model_args["tools"] = tools
|
||||||
|
|
||||||
|
if model.startswith("o"):
|
||||||
|
model_args["reasoning"] = {
|
||||||
|
"effort": options.get(
|
||||||
|
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||||
|
)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
model_args["store"] = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await client.responses.create(**model_args)
|
||||||
|
except openai.RateLimitError as err:
|
||||||
|
LOGGER.error("Rate limited by OpenAI: %s", err)
|
||||||
|
raise HomeAssistantError("Rate limited or insufficient funds") from err
|
||||||
|
except openai.OpenAIError as err:
|
||||||
|
LOGGER.error("Error talking to OpenAI: %s", err)
|
||||||
|
raise HomeAssistantError("Error talking to OpenAI") from err
|
||||||
|
|
||||||
|
async for content in chat_log.async_add_delta_content_stream(
|
||||||
|
self.entity_id, _transform_stream(chat_log, result, messages)
|
||||||
|
):
|
||||||
|
if not isinstance(content, conversation.AssistantContent):
|
||||||
|
messages.extend(_convert_content_to_param(content))
|
||||||
|
|
||||||
|
if not chat_log.unresponded_tool_results:
|
||||||
|
break
|
@@ -64,6 +64,51 @@
|
|||||||
"error": {
|
"error": {
|
||||||
"model_not_supported": "This model is not supported, please select a different model"
|
"model_not_supported": "This model is not supported, please select a different model"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"ai_task": {
|
||||||
|
"initiate_flow": {
|
||||||
|
"user": "Add AI task service",
|
||||||
|
"reconfigure": "Reconfigure AI task service"
|
||||||
|
},
|
||||||
|
"entry_type": "AI task service",
|
||||||
|
"step": {
|
||||||
|
"init": {
|
||||||
|
"data": {
|
||||||
|
"name": "[%key:common::config_flow::data::name%]",
|
||||||
|
"recommended": "[%key:component::openai_conversation::config_subentries::conversation::step::init::data::recommended%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"advanced": {
|
||||||
|
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::title%]",
|
||||||
|
"data": {
|
||||||
|
"chat_model": "[%key:common::generic::model%]",
|
||||||
|
"max_tokens": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::max_tokens%]",
|
||||||
|
"temperature": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::temperature%]",
|
||||||
|
"top_p": "[%key:component::openai_conversation::config_subentries::conversation::step::advanced::data::top_p%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"title": "[%key:component::openai_conversation::config_subentries::conversation::step::model::title%]",
|
||||||
|
"data": {
|
||||||
|
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::reasoning_effort%]",
|
||||||
|
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::web_search%]",
|
||||||
|
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::search_context_size%]",
|
||||||
|
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data::user_location%]"
|
||||||
|
},
|
||||||
|
"data_description": {
|
||||||
|
"reasoning_effort": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::reasoning_effort%]",
|
||||||
|
"web_search": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::web_search%]",
|
||||||
|
"search_context_size": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::search_context_size%]",
|
||||||
|
"user_location": "[%key:component::openai_conversation::config_subentries::conversation::step::model::data_description::user_location%]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"abort": {
|
||||||
|
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"model_not_supported": "[%key:component::openai_conversation::config_subentries::conversation::error::model_not_supported%]"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"selector": {
|
"selector": {
|
||||||
|
@@ -3420,6 +3420,11 @@ class ConfigSubentryFlow(
|
|||||||
"""Return config entry id."""
|
"""Return config entry id."""
|
||||||
return self.handler[0]
|
return self.handler[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _subentry_type(self) -> str:
|
||||||
|
"""Return type of subentry we are editing/creating."""
|
||||||
|
return self.handler[1]
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _get_entry(self) -> ConfigEntry:
|
def _get_entry(self) -> ConfigEntry:
|
||||||
"""Return the config entry linked to the current context."""
|
"""Return the config entry linked to the current context."""
|
||||||
|
87
tests/components/openai_conversation/common.py
Normal file
87
tests/components/openai_conversation/common.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""Common utilities for OpenAI conversation tests."""
|
||||||
|
|
||||||
|
from openai.types.responses import (
|
||||||
|
ResponseContentPartAddedEvent,
|
||||||
|
ResponseContentPartDoneEvent,
|
||||||
|
ResponseOutputItemAddedEvent,
|
||||||
|
ResponseOutputItemDoneEvent,
|
||||||
|
ResponseOutputMessage,
|
||||||
|
ResponseOutputText,
|
||||||
|
ResponseStreamEvent,
|
||||||
|
ResponseTextDeltaEvent,
|
||||||
|
ResponseTextDoneEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_message_item(
|
||||||
|
id: str, text: str | list[str], output_index: int
|
||||||
|
) -> list[ResponseStreamEvent]:
|
||||||
|
"""Create a message item."""
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
content = ResponseOutputText(annotations=[], text="", type="output_text")
|
||||||
|
events = [
|
||||||
|
ResponseOutputItemAddedEvent(
|
||||||
|
item=ResponseOutputMessage(
|
||||||
|
id=id,
|
||||||
|
content=[],
|
||||||
|
type="message",
|
||||||
|
role="assistant",
|
||||||
|
status="in_progress",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
type="response.output_item.added",
|
||||||
|
),
|
||||||
|
ResponseContentPartAddedEvent(
|
||||||
|
content_index=0,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
part=content,
|
||||||
|
type="response.content_part.added",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
content.text = "".join(text)
|
||||||
|
events.extend(
|
||||||
|
ResponseTextDeltaEvent(
|
||||||
|
content_index=0,
|
||||||
|
delta=delta,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
type="response.output_text.delta",
|
||||||
|
)
|
||||||
|
for delta in text
|
||||||
|
)
|
||||||
|
|
||||||
|
events.extend(
|
||||||
|
[
|
||||||
|
ResponseTextDoneEvent(
|
||||||
|
content_index=0,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
text="".join(text),
|
||||||
|
type="response.output_text.done",
|
||||||
|
),
|
||||||
|
ResponseContentPartDoneEvent(
|
||||||
|
content_index=0,
|
||||||
|
item_id=id,
|
||||||
|
output_index=output_index,
|
||||||
|
part=content,
|
||||||
|
type="response.content_part.done",
|
||||||
|
),
|
||||||
|
ResponseOutputItemDoneEvent(
|
||||||
|
item=ResponseOutputMessage(
|
||||||
|
id=id,
|
||||||
|
content=[content],
|
||||||
|
role="assistant",
|
||||||
|
status="completed",
|
||||||
|
type="message",
|
||||||
|
),
|
||||||
|
output_index=output_index,
|
||||||
|
type="response.output_item.done",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return events
|
@@ -1,16 +1,35 @@
|
|||||||
"""Tests helpers."""
|
"""Tests helpers."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from collections.abc import Generator
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from openai.types import ResponseFormatText
|
||||||
|
from openai.types.responses import (
|
||||||
|
Response,
|
||||||
|
ResponseCompletedEvent,
|
||||||
|
ResponseCreatedEvent,
|
||||||
|
ResponseError,
|
||||||
|
ResponseErrorEvent,
|
||||||
|
ResponseFailedEvent,
|
||||||
|
ResponseIncompleteEvent,
|
||||||
|
ResponseInProgressEvent,
|
||||||
|
ResponseOutputItemDoneEvent,
|
||||||
|
ResponseTextConfig,
|
||||||
|
)
|
||||||
|
from openai.types.responses.response import IncompleteDetails
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.openai_conversation.const import DEFAULT_CONVERSATION_NAME
|
from homeassistant.components.openai_conversation.const import (
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
|
DEFAULT_CONVERSATION_NAME,
|
||||||
|
)
|
||||||
from homeassistant.const import CONF_LLM_HASS_API
|
from homeassistant.const import CONF_LLM_HASS_API
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import llm
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
from tests.components.conversation import mock_chat_log # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -29,7 +48,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
|||||||
"subentry_type": "conversation",
|
"subentry_type": "conversation",
|
||||||
"title": DEFAULT_CONVERSATION_NAME,
|
"title": DEFAULT_CONVERSATION_NAME,
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"data": {},
|
||||||
|
"subentry_type": "ai_task",
|
||||||
|
"title": DEFAULT_AI_TASK_NAME,
|
||||||
|
"unique_id": None,
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
entry.add_to_hass(hass)
|
entry.add_to_hass(hass)
|
||||||
@@ -65,3 +90,89 @@ async def mock_init_component(
|
|||||||
async def setup_ha(hass: HomeAssistant) -> None:
|
async def setup_ha(hass: HomeAssistant) -> None:
|
||||||
"""Set up Home Assistant."""
|
"""Set up Home Assistant."""
|
||||||
assert await async_setup_component(hass, "homeassistant", {})
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_create_stream() -> Generator[AsyncMock]:
|
||||||
|
"""Mock stream response."""
|
||||||
|
|
||||||
|
async def mock_generator(events, **kwargs):
|
||||||
|
response = Response(
|
||||||
|
id="resp_A",
|
||||||
|
created_at=1700000000,
|
||||||
|
error=None,
|
||||||
|
incomplete_details=None,
|
||||||
|
instructions=kwargs.get("instructions"),
|
||||||
|
metadata=kwargs.get("metadata", {}),
|
||||||
|
model=kwargs.get("model", "gpt-4o-mini"),
|
||||||
|
object="response",
|
||||||
|
output=[],
|
||||||
|
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
|
||||||
|
temperature=kwargs.get("temperature", 1.0),
|
||||||
|
tool_choice=kwargs.get("tool_choice", "auto"),
|
||||||
|
tools=kwargs.get("tools", []),
|
||||||
|
top_p=kwargs.get("top_p", 1.0),
|
||||||
|
max_output_tokens=kwargs.get("max_output_tokens", 100000),
|
||||||
|
previous_response_id=kwargs.get("previous_response_id"),
|
||||||
|
reasoning=kwargs.get("reasoning"),
|
||||||
|
status="in_progress",
|
||||||
|
text=kwargs.get(
|
||||||
|
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
|
||||||
|
),
|
||||||
|
truncation=kwargs.get("truncation", "disabled"),
|
||||||
|
usage=None,
|
||||||
|
user=kwargs.get("user"),
|
||||||
|
store=kwargs.get("store", True),
|
||||||
|
)
|
||||||
|
yield ResponseCreatedEvent(
|
||||||
|
response=response,
|
||||||
|
type="response.created",
|
||||||
|
)
|
||||||
|
yield ResponseInProgressEvent(
|
||||||
|
response=response,
|
||||||
|
type="response.in_progress",
|
||||||
|
)
|
||||||
|
response.status = "completed"
|
||||||
|
|
||||||
|
for value in events:
|
||||||
|
if isinstance(value, ResponseOutputItemDoneEvent):
|
||||||
|
response.output.append(value.item)
|
||||||
|
elif isinstance(value, IncompleteDetails):
|
||||||
|
response.status = "incomplete"
|
||||||
|
response.incomplete_details = value
|
||||||
|
break
|
||||||
|
if isinstance(value, ResponseError):
|
||||||
|
response.status = "failed"
|
||||||
|
response.error = value
|
||||||
|
break
|
||||||
|
|
||||||
|
yield value
|
||||||
|
|
||||||
|
if isinstance(value, ResponseErrorEvent):
|
||||||
|
return
|
||||||
|
|
||||||
|
if response.status == "incomplete":
|
||||||
|
yield ResponseIncompleteEvent(
|
||||||
|
response=response,
|
||||||
|
type="response.incomplete",
|
||||||
|
)
|
||||||
|
elif response.status == "failed":
|
||||||
|
yield ResponseFailedEvent(
|
||||||
|
response=response,
|
||||||
|
type="response.failed",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield ResponseCompletedEvent(
|
||||||
|
response=response,
|
||||||
|
type="response.completed",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"openai.resources.responses.AsyncResponses.create",
|
||||||
|
AsyncMock(),
|
||||||
|
) as mock_create:
|
||||||
|
mock_create.side_effect = lambda **kwargs: mock_generator(
|
||||||
|
mock_create.return_value.pop(0), **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
yield mock_create
|
||||||
|
33
tests/components/openai_conversation/test_ai_task.py
Normal file
33
tests/components/openai_conversation/test_ai_task.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Test AI Task platform of OpenAI Conversation integration."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import ai_task
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from .common import create_message_item
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_init_component")
|
||||||
|
async def test_ai_task_generate_text(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
mock_create_stream: AsyncMock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that AI task can generate text."""
|
||||||
|
entity_id = "ai_task.openai_ai_task"
|
||||||
|
mock_create_stream.return_value = [
|
||||||
|
create_message_item(id="msg_A", text="Hi there!", output_index=0)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ai_task.async_generate_text(
|
||||||
|
hass,
|
||||||
|
task_name="Test Task",
|
||||||
|
entity_id=entity_id,
|
||||||
|
instructions="Test prompt",
|
||||||
|
)
|
||||||
|
assert result.text == "Hi there!"
|
@@ -8,7 +8,6 @@ from openai.types.responses import Response, ResponseOutputMessage, ResponseOutp
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.openai_conversation.config_flow import RECOMMENDED_OPTIONS
|
|
||||||
from homeassistant.components.openai_conversation.const import (
|
from homeassistant.components.openai_conversation.const import (
|
||||||
CONF_CHAT_MODEL,
|
CONF_CHAT_MODEL,
|
||||||
CONF_MAX_TOKENS,
|
CONF_MAX_TOKENS,
|
||||||
@@ -24,9 +23,12 @@ from homeassistant.components.openai_conversation.const import (
|
|||||||
CONF_WEB_SEARCH_REGION,
|
CONF_WEB_SEARCH_REGION,
|
||||||
CONF_WEB_SEARCH_TIMEZONE,
|
CONF_WEB_SEARCH_TIMEZONE,
|
||||||
CONF_WEB_SEARCH_USER_LOCATION,
|
CONF_WEB_SEARCH_USER_LOCATION,
|
||||||
|
DEFAULT_AI_TASK_NAME,
|
||||||
DEFAULT_CONVERSATION_NAME,
|
DEFAULT_CONVERSATION_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
|
RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
|
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
RECOMMENDED_MAX_TOKENS,
|
RECOMMENDED_MAX_TOKENS,
|
||||||
RECOMMENDED_TOP_P,
|
RECOMMENDED_TOP_P,
|
||||||
)
|
)
|
||||||
@@ -77,10 +79,16 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
assert result2["subentries"] == [
|
assert result2["subentries"] == [
|
||||||
{
|
{
|
||||||
"subentry_type": "conversation",
|
"subentry_type": "conversation",
|
||||||
"data": RECOMMENDED_OPTIONS,
|
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||||
"title": DEFAULT_CONVERSATION_NAME,
|
"title": DEFAULT_CONVERSATION_NAME,
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
"subentry_type": "ai_task",
|
||||||
|
"data": RECOMMENDED_AI_TASK_OPTIONS,
|
||||||
|
"title": DEFAULT_AI_TASK_NAME,
|
||||||
|
"unique_id": None,
|
||||||
|
},
|
||||||
]
|
]
|
||||||
assert len(mock_setup_entry.mock_calls) == 1
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
@@ -104,19 +112,56 @@ async def test_creating_conversation_subentry(
|
|||||||
|
|
||||||
result2 = await hass.config_entries.subentries.async_configure(
|
result2 = await hass.config_entries.subentries.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
{"name": "My Custom Agent", **RECOMMENDED_OPTIONS},
|
{"name": "My Custom Agent", **RECOMMENDED_CONVERSATION_OPTIONS},
|
||||||
)
|
)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||||
assert result2["title"] == "My Custom Agent"
|
assert result2["title"] == "My Custom Agent"
|
||||||
|
|
||||||
processed_options = RECOMMENDED_OPTIONS.copy()
|
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||||
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
|
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
|
||||||
|
|
||||||
assert result2["data"] == processed_options
|
assert result2["data"] == processed_options
|
||||||
|
|
||||||
|
|
||||||
|
async def test_creating_ai_task_subentry(
|
||||||
|
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||||
|
) -> None:
|
||||||
|
"""Test creating an AI task subentry."""
|
||||||
|
with patch("openai.resources.models.AsyncModels.list"):
|
||||||
|
result = await hass.config_entries.subentries.async_init(
|
||||||
|
(mock_config_entry.entry_id, "ai_task"),
|
||||||
|
context={"source": config_entries.SOURCE_USER},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] is FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "init"
|
||||||
|
assert not result["errors"]
|
||||||
|
|
||||||
|
old_subentries = set(mock_config_entry.subentries)
|
||||||
|
|
||||||
|
with patch("openai.resources.models.AsyncModels.list"):
|
||||||
|
result2 = await hass.config_entries.subentries.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"name": "My AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||||
|
assert result2["title"] == "My AI Task"
|
||||||
|
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
|
||||||
|
|
||||||
|
assert len(mock_config_entry.subentries) == 3
|
||||||
|
|
||||||
|
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||||
|
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||||
|
|
||||||
|
assert new_subentry.subentry_type == "ai_task"
|
||||||
|
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||||
|
assert new_subentry.title == "My AI Task"
|
||||||
|
|
||||||
|
|
||||||
async def test_subentry_recommended(
|
async def test_subentry_recommended(
|
||||||
hass: HomeAssistant, mock_config_entry, mock_init_component
|
hass: HomeAssistant, mock_config_entry, mock_init_component
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@@ -1,35 +1,20 @@
|
|||||||
"""Tests for the OpenAI integration."""
|
"""Tests for the OpenAI integration."""
|
||||||
|
|
||||||
from collections.abc import Generator
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from openai import AuthenticationError, RateLimitError
|
from openai import AuthenticationError, RateLimitError
|
||||||
from openai.types import ResponseFormatText
|
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
Response,
|
|
||||||
ResponseCompletedEvent,
|
|
||||||
ResponseContentPartAddedEvent,
|
|
||||||
ResponseContentPartDoneEvent,
|
|
||||||
ResponseCreatedEvent,
|
|
||||||
ResponseError,
|
ResponseError,
|
||||||
ResponseErrorEvent,
|
ResponseErrorEvent,
|
||||||
ResponseFailedEvent,
|
|
||||||
ResponseFunctionCallArgumentsDeltaEvent,
|
ResponseFunctionCallArgumentsDeltaEvent,
|
||||||
ResponseFunctionCallArgumentsDoneEvent,
|
ResponseFunctionCallArgumentsDoneEvent,
|
||||||
ResponseFunctionToolCall,
|
ResponseFunctionToolCall,
|
||||||
ResponseFunctionWebSearch,
|
ResponseFunctionWebSearch,
|
||||||
ResponseIncompleteEvent,
|
|
||||||
ResponseInProgressEvent,
|
|
||||||
ResponseOutputItemAddedEvent,
|
ResponseOutputItemAddedEvent,
|
||||||
ResponseOutputItemDoneEvent,
|
ResponseOutputItemDoneEvent,
|
||||||
ResponseOutputMessage,
|
|
||||||
ResponseOutputText,
|
|
||||||
ResponseReasoningItem,
|
ResponseReasoningItem,
|
||||||
ResponseStreamEvent,
|
ResponseStreamEvent,
|
||||||
ResponseTextConfig,
|
|
||||||
ResponseTextDeltaEvent,
|
|
||||||
ResponseTextDoneEvent,
|
|
||||||
ResponseWebSearchCallCompletedEvent,
|
ResponseWebSearchCallCompletedEvent,
|
||||||
ResponseWebSearchCallInProgressEvent,
|
ResponseWebSearchCallInProgressEvent,
|
||||||
ResponseWebSearchCallSearchingEvent,
|
ResponseWebSearchCallSearchingEvent,
|
||||||
@@ -54,6 +39,8 @@ from homeassistant.core import Context, HomeAssistant
|
|||||||
from homeassistant.helpers import intent
|
from homeassistant.helpers import intent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from .common import create_message_item
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
from tests.components.conversation import (
|
from tests.components.conversation import (
|
||||||
MockChatLog,
|
MockChatLog,
|
||||||
@@ -61,92 +48,6 @@ from tests.components.conversation import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_create_stream() -> Generator[AsyncMock]:
|
|
||||||
"""Mock stream response."""
|
|
||||||
|
|
||||||
async def mock_generator(events, **kwargs):
|
|
||||||
response = Response(
|
|
||||||
id="resp_A",
|
|
||||||
created_at=1700000000,
|
|
||||||
error=None,
|
|
||||||
incomplete_details=None,
|
|
||||||
instructions=kwargs.get("instructions"),
|
|
||||||
metadata=kwargs.get("metadata", {}),
|
|
||||||
model=kwargs.get("model", "gpt-4o-mini"),
|
|
||||||
object="response",
|
|
||||||
output=[],
|
|
||||||
parallel_tool_calls=kwargs.get("parallel_tool_calls", True),
|
|
||||||
temperature=kwargs.get("temperature", 1.0),
|
|
||||||
tool_choice=kwargs.get("tool_choice", "auto"),
|
|
||||||
tools=kwargs.get("tools"),
|
|
||||||
top_p=kwargs.get("top_p", 1.0),
|
|
||||||
max_output_tokens=kwargs.get("max_output_tokens", 100000),
|
|
||||||
previous_response_id=kwargs.get("previous_response_id"),
|
|
||||||
reasoning=kwargs.get("reasoning"),
|
|
||||||
status="in_progress",
|
|
||||||
text=kwargs.get(
|
|
||||||
"text", ResponseTextConfig(format=ResponseFormatText(type="text"))
|
|
||||||
),
|
|
||||||
truncation=kwargs.get("truncation", "disabled"),
|
|
||||||
usage=None,
|
|
||||||
user=kwargs.get("user"),
|
|
||||||
store=kwargs.get("store", True),
|
|
||||||
)
|
|
||||||
yield ResponseCreatedEvent(
|
|
||||||
response=response,
|
|
||||||
type="response.created",
|
|
||||||
)
|
|
||||||
yield ResponseInProgressEvent(
|
|
||||||
response=response,
|
|
||||||
type="response.in_progress",
|
|
||||||
)
|
|
||||||
response.status = "completed"
|
|
||||||
|
|
||||||
for value in events:
|
|
||||||
if isinstance(value, ResponseOutputItemDoneEvent):
|
|
||||||
response.output.append(value.item)
|
|
||||||
elif isinstance(value, IncompleteDetails):
|
|
||||||
response.status = "incomplete"
|
|
||||||
response.incomplete_details = value
|
|
||||||
break
|
|
||||||
if isinstance(value, ResponseError):
|
|
||||||
response.status = "failed"
|
|
||||||
response.error = value
|
|
||||||
break
|
|
||||||
|
|
||||||
yield value
|
|
||||||
|
|
||||||
if isinstance(value, ResponseErrorEvent):
|
|
||||||
return
|
|
||||||
|
|
||||||
if response.status == "incomplete":
|
|
||||||
yield ResponseIncompleteEvent(
|
|
||||||
response=response,
|
|
||||||
type="response.incomplete",
|
|
||||||
)
|
|
||||||
elif response.status == "failed":
|
|
||||||
yield ResponseFailedEvent(
|
|
||||||
response=response,
|
|
||||||
type="response.failed",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
yield ResponseCompletedEvent(
|
|
||||||
response=response,
|
|
||||||
type="response.completed",
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"openai.resources.responses.AsyncResponses.create",
|
|
||||||
AsyncMock(),
|
|
||||||
) as mock_create:
|
|
||||||
mock_create.side_effect = lambda **kwargs: mock_generator(
|
|
||||||
mock_create.return_value.pop(0), **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
yield mock_create
|
|
||||||
|
|
||||||
|
|
||||||
async def test_entity(
|
async def test_entity(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_config_entry: MockConfigEntry,
|
mock_config_entry: MockConfigEntry,
|
||||||
@@ -341,80 +242,6 @@ async def test_conversation_agent(
|
|||||||
assert agent.supported_languages == "*"
|
assert agent.supported_languages == "*"
|
||||||
|
|
||||||
|
|
||||||
def create_message_item(
|
|
||||||
id: str, text: str | list[str], output_index: int
|
|
||||||
) -> list[ResponseStreamEvent]:
|
|
||||||
"""Create a message item."""
|
|
||||||
if isinstance(text, str):
|
|
||||||
text = [text]
|
|
||||||
|
|
||||||
content = ResponseOutputText(annotations=[], text="", type="output_text")
|
|
||||||
events = [
|
|
||||||
ResponseOutputItemAddedEvent(
|
|
||||||
item=ResponseOutputMessage(
|
|
||||||
id=id,
|
|
||||||
content=[],
|
|
||||||
type="message",
|
|
||||||
role="assistant",
|
|
||||||
status="in_progress",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
type="response.output_item.added",
|
|
||||||
),
|
|
||||||
ResponseContentPartAddedEvent(
|
|
||||||
content_index=0,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
part=content,
|
|
||||||
type="response.content_part.added",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
content.text = "".join(text)
|
|
||||||
events.extend(
|
|
||||||
ResponseTextDeltaEvent(
|
|
||||||
content_index=0,
|
|
||||||
delta=delta,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
type="response.output_text.delta",
|
|
||||||
)
|
|
||||||
for delta in text
|
|
||||||
)
|
|
||||||
|
|
||||||
events.extend(
|
|
||||||
[
|
|
||||||
ResponseTextDoneEvent(
|
|
||||||
content_index=0,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
text="".join(text),
|
|
||||||
type="response.output_text.done",
|
|
||||||
),
|
|
||||||
ResponseContentPartDoneEvent(
|
|
||||||
content_index=0,
|
|
||||||
item_id=id,
|
|
||||||
output_index=output_index,
|
|
||||||
part=content,
|
|
||||||
type="response.content_part.done",
|
|
||||||
),
|
|
||||||
ResponseOutputItemDoneEvent(
|
|
||||||
item=ResponseOutputMessage(
|
|
||||||
id=id,
|
|
||||||
content=[content],
|
|
||||||
role="assistant",
|
|
||||||
status="completed",
|
|
||||||
type="message",
|
|
||||||
),
|
|
||||||
output_index=output_index,
|
|
||||||
type="response.output_item.done",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return events
|
|
||||||
|
|
||||||
|
|
||||||
def create_function_tool_call_item(
|
def create_function_tool_call_item(
|
||||||
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
|
id: str, arguments: str | list[str], call_id: str, name: str, output_index: int
|
||||||
) -> list[ResponseStreamEvent]:
|
) -> list[ResponseStreamEvent]:
|
||||||
|
Reference in New Issue
Block a user