OpenAI external tools (#150599)

This commit is contained in:
Denis Shulyaka
2025-08-20 00:36:23 +03:00
committed by GitHub
parent e68df66028
commit 9797d391af
4 changed files with 237 additions and 48 deletions

View File

@@ -14,6 +14,7 @@ from openai._streaming import AsyncStream
from openai.types.responses import (
EasyInputMessageParam,
FunctionToolParam,
ResponseCodeInterpreterToolCall,
ResponseCompletedEvent,
ResponseErrorEvent,
ResponseFailedEvent,
@@ -21,6 +22,8 @@ from openai.types.responses import (
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionToolCall,
ResponseFunctionToolCallParam,
ResponseFunctionWebSearch,
ResponseFunctionWebSearchParam,
ResponseIncompleteEvent,
ResponseInputFileParam,
ResponseInputImageParam,
@@ -149,16 +152,27 @@ def _convert_content_to_param(
"""Convert any native chat message for this agent to the native format."""
messages: ResponseInputParam = []
reasoning_summary: list[str] = []
web_search_calls: dict[str, ResponseFunctionWebSearchParam] = {}
for content in chat_content:
if isinstance(content, conversation.ToolResultContent):
messages.append(
FunctionCallOutput(
type="function_call_output",
call_id=content.tool_call_id,
output=json.dumps(content.tool_result),
if (
content.tool_name == "web_search_call"
and content.tool_call_id in web_search_calls
):
web_search_call = web_search_calls.pop(content.tool_call_id)
web_search_call["status"] = content.tool_result.get( # type: ignore[typeddict-item]
"status", "completed"
)
messages.append(web_search_call)
else:
messages.append(
FunctionCallOutput(
type="function_call_output",
call_id=content.tool_call_id,
output=json.dumps(content.tool_result),
)
)
)
continue
if content.content:
@@ -173,15 +187,27 @@ def _convert_content_to_param(
if isinstance(content, conversation.AssistantContent):
if content.tool_calls:
messages.extend(
ResponseFunctionToolCallParam(
type="function_call",
name=tool_call.tool_name,
arguments=json.dumps(tool_call.tool_args),
call_id=tool_call.id,
)
for tool_call in content.tool_calls
)
for tool_call in content.tool_calls:
if (
tool_call.external
and tool_call.tool_name == "web_search_call"
and "action" in tool_call.tool_args
):
web_search_calls[tool_call.id] = ResponseFunctionWebSearchParam(
type="web_search_call",
id=tool_call.id,
action=tool_call.tool_args["action"],
status="completed",
)
else:
messages.append(
ResponseFunctionToolCallParam(
type="function_call",
name=tool_call.tool_name,
arguments=json.dumps(tool_call.tool_args),
call_id=tool_call.id,
)
)
if content.thinking_content:
reasoning_summary.append(content.thinking_content)
@@ -211,25 +237,37 @@ def _convert_content_to_param(
async def _transform_stream(
chat_log: conversation.ChatLog,
stream: AsyncStream[ResponseStreamEvent],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
) -> AsyncGenerator[
conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict
]:
"""Transform an OpenAI delta stream into HA format."""
last_summary_index = None
last_role: Literal["assistant", "tool_result"] | None = None
async for event in stream:
LOGGER.debug("Received event: %s", event)
if isinstance(event, ResponseOutputItemAddedEvent):
if isinstance(event.item, ResponseOutputMessage):
yield {"role": event.item.role}
last_summary_index = None
elif isinstance(event.item, ResponseFunctionToolCall):
if isinstance(event.item, ResponseFunctionToolCall):
# OpenAI has tool calls as individual events
# while HA puts tool calls inside the assistant message.
# We turn them into individual assistant content for HA
# to ensure that tools are called as soon as possible.
yield {"role": "assistant"}
last_role = "assistant"
last_summary_index = None
current_tool_call = event.item
elif (
isinstance(event.item, ResponseOutputMessage)
or (
isinstance(event.item, ResponseReasoningItem)
and last_summary_index is not None
) # Subsequent ResponseReasoningItem
or last_role != "assistant"
):
yield {"role": "assistant"}
last_role = "assistant"
last_summary_index = None
elif isinstance(event, ResponseOutputItemDoneEvent):
if isinstance(event.item, ResponseReasoningItem):
yield {
@@ -240,6 +278,52 @@ async def _transform_stream(
encrypted_content=event.item.encrypted_content,
)
}
last_summary_index = len(event.item.summary) - 1
elif isinstance(event.item, ResponseCodeInterpreterToolCall):
yield {
"tool_calls": [
llm.ToolInput(
id=event.item.id,
tool_name="code_interpreter",
tool_args={
"code": event.item.code,
"container": event.item.container_id,
},
external=True,
)
]
}
yield {
"role": "tool_result",
"tool_call_id": event.item.id,
"tool_name": "code_interpreter",
"tool_result": {
"output": [output.to_dict() for output in event.item.outputs] # type: ignore[misc]
if event.item.outputs is not None
else None
},
}
last_role = "tool_result"
elif isinstance(event.item, ResponseFunctionWebSearch):
yield {
"tool_calls": [
llm.ToolInput(
id=event.item.id,
tool_name="web_search_call",
tool_args={
"action": event.item.action.to_dict(),
},
external=True,
)
]
}
yield {
"role": "tool_result",
"tool_call_id": event.item.id,
"tool_name": "web_search_call",
"tool_result": {"status": event.item.status},
}
last_role = "tool_result"
elif isinstance(event, ResponseTextDeltaEvent):
yield {"content": event.delta}
elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
@@ -252,6 +336,7 @@ async def _transform_stream(
and event.summary_index != last_summary_index
):
yield {"role": "assistant"}
last_role = "assistant"
last_summary_index = event.summary_index
yield {"thinking_content": event.delta}
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
@@ -348,6 +433,33 @@ class OpenAIBaseLLMEntity(Entity):
"""Generate an answer for the chat log."""
options = self.subentry.data
messages = _convert_content_to_param(chat_log.content)
model_args = ResponseCreateParamsStreaming(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
input=messages,
max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=chat_log.conversation_id,
store=False,
stream=True,
)
if model_args["model"].startswith(("o", "gpt-5")):
model_args["reasoning"] = {
"effort": options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
),
"summary": "auto",
}
model_args["include"] = ["reasoning.encrypted_content"]
if model_args["model"].startswith("gpt-5"):
model_args["text"] = {
"verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
}
tools: list[ToolParam] = []
if chat_log.llm_api:
tools = [
@@ -381,36 +493,11 @@ class OpenAIBaseLLMEntity(Entity):
),
)
)
model_args.setdefault("include", []).append("code_interpreter_call.outputs") # type: ignore[union-attr]
messages = _convert_content_to_param(chat_log.content)
model_args = ResponseCreateParamsStreaming(
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
input=messages,
max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
user=chat_log.conversation_id,
store=False,
stream=True,
)
if tools:
model_args["tools"] = tools
if model_args["model"].startswith(("o", "gpt-5")):
model_args["reasoning"] = {
"effort": options.get(
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
),
"summary": "auto",
}
model_args["include"] = ["reasoning.encrypted_content"]
if model_args["model"].startswith("gpt-5"):
model_args["text"] = {
"verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
}
last_content = chat_log.content[-1]
# Handle attachments by adding them to the last user message

View File

@@ -29,6 +29,7 @@ from openai.types.responses import (
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
)
from openai.types.responses.response_code_interpreter_tool_call import OutputLogs
from openai.types.responses.response_function_web_search import ActionSearch
from openai.types.responses.response_reasoning_item import Summary
@@ -320,7 +321,7 @@ def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEve
def create_code_interpreter_item(
id: str, code: str | list[str], output_index: int
id: str, code: str | list[str], output_index: int, logs: str | None = None
) -> list[ResponseStreamEvent]:
"""Create a message item."""
if isinstance(code, str):
@@ -388,7 +389,7 @@ def create_code_interpreter_item(
id=id,
code=code,
container_id=container_id,
outputs=None,
outputs=[OutputLogs(type="logs", logs=logs)] if logs else None,
status="completed",
type="code_interpreter_call",
),

View File

@@ -1,4 +1,39 @@
# serializer version: 1
# name: test_code_interpreter
list([
dict({
'content': 'Please use the python tool to calculate square root of 55555',
'role': 'user',
'type': 'message',
}),
dict({
'arguments': '{"code": "import math\\nmath.sqrt(55555)", "container": "cntr_A"}',
'call_id': 'ci_A',
'name': 'code_interpreter',
'type': 'function_call',
}),
dict({
'call_id': 'ci_A',
'output': '{"output": [{"logs": "235.70108188126758\\n", "type": "logs"}]}',
'type': 'function_call_output',
}),
dict({
'content': 'Ive calculated it with Python: the square root of 55555 is approximately 235.70108188126758.',
'role': 'assistant',
'type': 'message',
}),
dict({
'content': 'Thank you!',
'role': 'user',
'type': 'message',
}),
dict({
'content': 'You are welcome!',
'role': 'assistant',
'type': 'message',
}),
])
# ---
# name: test_function_call
list([
dict({
@@ -172,3 +207,36 @@
}),
])
# ---
# name: test_web_search
list([
dict({
'content': "What's on the latest news?",
'role': 'user',
'type': 'message',
}),
dict({
'action': dict({
'query': 'query',
'type': 'search',
}),
'id': 'ws_A',
'status': 'completed',
'type': 'web_search_call',
}),
dict({
'content': 'Home Assistant now supports ChatGPT Search in Assist',
'role': 'assistant',
'type': 'message',
}),
dict({
'content': 'Thank you!',
'role': 'user',
'type': 'message',
}),
dict({
'content': 'You are welcome!',
'role': 'assistant',
'type': 'message',
}),
])
# ---

View File

@@ -435,6 +435,7 @@ async def test_web_search(
mock_init_component,
mock_create_stream,
mock_chat_log: MockChatLog, # noqa: F811
snapshot: SnapshotAssertion,
) -> None:
"""Test web_search_tool."""
subentry = next(iter(mock_config_entry.subentries.values()))
@@ -487,6 +488,21 @@ async def test_web_search(
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.speech["plain"]["speech"] == message, result.response.speech
# Test follow-up message in multi-turn conversation
mock_create_stream.return_value = [
(*create_message_item(id="msg_B", text="You are welcome!", output_index=1),)
]
result = await conversation.async_converse(
hass,
"Thank you!",
mock_chat_log.conversation_id,
Context(),
agent_id="conversation.openai_conversation",
)
assert mock_create_stream.mock_calls[1][2]["input"][1:] == snapshot
async def test_code_interpreter(
hass: HomeAssistant,
@@ -494,6 +510,7 @@ async def test_code_interpreter(
mock_init_component,
mock_create_stream,
mock_chat_log: MockChatLog, # noqa: F811
snapshot: SnapshotAssertion,
) -> None:
"""Test code_interpreter tool."""
subentry = next(iter(mock_config_entry.subentries.values()))
@@ -513,6 +530,7 @@ async def test_code_interpreter(
*create_code_interpreter_item(
id="ci_A",
code=["import", " math", "\n", "math", ".sqrt", "(", "555", "55", ")"],
logs="235.70108188126758\n",
output_index=0,
),
*create_message_item(id="msg_A", text=message, output_index=1),
@@ -532,3 +550,18 @@ async def test_code_interpreter(
]
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
assert result.response.speech["plain"]["speech"] == message, result.response.speech
# Test follow-up message in multi-turn conversation
mock_create_stream.return_value = [
(*create_message_item(id="msg_B", text="You are welcome!", output_index=1),)
]
result = await conversation.async_converse(
hass,
"Thank you!",
mock_chat_log.conversation_id,
Context(),
agent_id="conversation.openai_conversation",
)
assert mock_create_stream.mock_calls[1][2]["input"][1:] == snapshot