mirror of
https://github.com/home-assistant/core.git
synced 2025-08-31 02:11:32 +02:00
OpenAI external tools (#150599)
This commit is contained in:
@@ -14,6 +14,7 @@ from openai._streaming import AsyncStream
|
||||
from openai.types.responses import (
|
||||
EasyInputMessageParam,
|
||||
FunctionToolParam,
|
||||
ResponseCodeInterpreterToolCall,
|
||||
ResponseCompletedEvent,
|
||||
ResponseErrorEvent,
|
||||
ResponseFailedEvent,
|
||||
@@ -21,6 +22,8 @@ from openai.types.responses import (
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseFunctionToolCallParam,
|
||||
ResponseFunctionWebSearch,
|
||||
ResponseFunctionWebSearchParam,
|
||||
ResponseIncompleteEvent,
|
||||
ResponseInputFileParam,
|
||||
ResponseInputImageParam,
|
||||
@@ -149,16 +152,27 @@ def _convert_content_to_param(
|
||||
"""Convert any native chat message for this agent to the native format."""
|
||||
messages: ResponseInputParam = []
|
||||
reasoning_summary: list[str] = []
|
||||
web_search_calls: dict[str, ResponseFunctionWebSearchParam] = {}
|
||||
|
||||
for content in chat_content:
|
||||
if isinstance(content, conversation.ToolResultContent):
|
||||
messages.append(
|
||||
FunctionCallOutput(
|
||||
type="function_call_output",
|
||||
call_id=content.tool_call_id,
|
||||
output=json.dumps(content.tool_result),
|
||||
if (
|
||||
content.tool_name == "web_search_call"
|
||||
and content.tool_call_id in web_search_calls
|
||||
):
|
||||
web_search_call = web_search_calls.pop(content.tool_call_id)
|
||||
web_search_call["status"] = content.tool_result.get( # type: ignore[typeddict-item]
|
||||
"status", "completed"
|
||||
)
|
||||
messages.append(web_search_call)
|
||||
else:
|
||||
messages.append(
|
||||
FunctionCallOutput(
|
||||
type="function_call_output",
|
||||
call_id=content.tool_call_id,
|
||||
output=json.dumps(content.tool_result),
|
||||
)
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if content.content:
|
||||
@@ -173,15 +187,27 @@ def _convert_content_to_param(
|
||||
|
||||
if isinstance(content, conversation.AssistantContent):
|
||||
if content.tool_calls:
|
||||
messages.extend(
|
||||
ResponseFunctionToolCallParam(
|
||||
type="function_call",
|
||||
name=tool_call.tool_name,
|
||||
arguments=json.dumps(tool_call.tool_args),
|
||||
call_id=tool_call.id,
|
||||
)
|
||||
for tool_call in content.tool_calls
|
||||
)
|
||||
for tool_call in content.tool_calls:
|
||||
if (
|
||||
tool_call.external
|
||||
and tool_call.tool_name == "web_search_call"
|
||||
and "action" in tool_call.tool_args
|
||||
):
|
||||
web_search_calls[tool_call.id] = ResponseFunctionWebSearchParam(
|
||||
type="web_search_call",
|
||||
id=tool_call.id,
|
||||
action=tool_call.tool_args["action"],
|
||||
status="completed",
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
ResponseFunctionToolCallParam(
|
||||
type="function_call",
|
||||
name=tool_call.tool_name,
|
||||
arguments=json.dumps(tool_call.tool_args),
|
||||
call_id=tool_call.id,
|
||||
)
|
||||
)
|
||||
|
||||
if content.thinking_content:
|
||||
reasoning_summary.append(content.thinking_content)
|
||||
@@ -211,25 +237,37 @@ def _convert_content_to_param(
|
||||
async def _transform_stream(
|
||||
chat_log: conversation.ChatLog,
|
||||
stream: AsyncStream[ResponseStreamEvent],
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
) -> AsyncGenerator[
|
||||
conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict
|
||||
]:
|
||||
"""Transform an OpenAI delta stream into HA format."""
|
||||
last_summary_index = None
|
||||
last_role: Literal["assistant", "tool_result"] | None = None
|
||||
|
||||
async for event in stream:
|
||||
LOGGER.debug("Received event: %s", event)
|
||||
|
||||
if isinstance(event, ResponseOutputItemAddedEvent):
|
||||
if isinstance(event.item, ResponseOutputMessage):
|
||||
yield {"role": event.item.role}
|
||||
last_summary_index = None
|
||||
elif isinstance(event.item, ResponseFunctionToolCall):
|
||||
if isinstance(event.item, ResponseFunctionToolCall):
|
||||
# OpenAI has tool calls as individual events
|
||||
# while HA puts tool calls inside the assistant message.
|
||||
# We turn them into individual assistant content for HA
|
||||
# to ensure that tools are called as soon as possible.
|
||||
yield {"role": "assistant"}
|
||||
last_role = "assistant"
|
||||
last_summary_index = None
|
||||
current_tool_call = event.item
|
||||
elif (
|
||||
isinstance(event.item, ResponseOutputMessage)
|
||||
or (
|
||||
isinstance(event.item, ResponseReasoningItem)
|
||||
and last_summary_index is not None
|
||||
) # Subsequent ResponseReasoningItem
|
||||
or last_role != "assistant"
|
||||
):
|
||||
yield {"role": "assistant"}
|
||||
last_role = "assistant"
|
||||
last_summary_index = None
|
||||
elif isinstance(event, ResponseOutputItemDoneEvent):
|
||||
if isinstance(event.item, ResponseReasoningItem):
|
||||
yield {
|
||||
@@ -240,6 +278,52 @@ async def _transform_stream(
|
||||
encrypted_content=event.item.encrypted_content,
|
||||
)
|
||||
}
|
||||
last_summary_index = len(event.item.summary) - 1
|
||||
elif isinstance(event.item, ResponseCodeInterpreterToolCall):
|
||||
yield {
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
id=event.item.id,
|
||||
tool_name="code_interpreter",
|
||||
tool_args={
|
||||
"code": event.item.code,
|
||||
"container": event.item.container_id,
|
||||
},
|
||||
external=True,
|
||||
)
|
||||
]
|
||||
}
|
||||
yield {
|
||||
"role": "tool_result",
|
||||
"tool_call_id": event.item.id,
|
||||
"tool_name": "code_interpreter",
|
||||
"tool_result": {
|
||||
"output": [output.to_dict() for output in event.item.outputs] # type: ignore[misc]
|
||||
if event.item.outputs is not None
|
||||
else None
|
||||
},
|
||||
}
|
||||
last_role = "tool_result"
|
||||
elif isinstance(event.item, ResponseFunctionWebSearch):
|
||||
yield {
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
id=event.item.id,
|
||||
tool_name="web_search_call",
|
||||
tool_args={
|
||||
"action": event.item.action.to_dict(),
|
||||
},
|
||||
external=True,
|
||||
)
|
||||
]
|
||||
}
|
||||
yield {
|
||||
"role": "tool_result",
|
||||
"tool_call_id": event.item.id,
|
||||
"tool_name": "web_search_call",
|
||||
"tool_result": {"status": event.item.status},
|
||||
}
|
||||
last_role = "tool_result"
|
||||
elif isinstance(event, ResponseTextDeltaEvent):
|
||||
yield {"content": event.delta}
|
||||
elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
|
||||
@@ -252,6 +336,7 @@ async def _transform_stream(
|
||||
and event.summary_index != last_summary_index
|
||||
):
|
||||
yield {"role": "assistant"}
|
||||
last_role = "assistant"
|
||||
last_summary_index = event.summary_index
|
||||
yield {"thinking_content": event.delta}
|
||||
elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
|
||||
@@ -348,6 +433,33 @@ class OpenAIBaseLLMEntity(Entity):
|
||||
"""Generate an answer for the chat log."""
|
||||
options = self.subentry.data
|
||||
|
||||
messages = _convert_content_to_param(chat_log.content)
|
||||
|
||||
model_args = ResponseCreateParamsStreaming(
|
||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||
input=messages,
|
||||
max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
user=chat_log.conversation_id,
|
||||
store=False,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
if model_args["model"].startswith(("o", "gpt-5")):
|
||||
model_args["reasoning"] = {
|
||||
"effort": options.get(
|
||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||
),
|
||||
"summary": "auto",
|
||||
}
|
||||
model_args["include"] = ["reasoning.encrypted_content"]
|
||||
|
||||
if model_args["model"].startswith("gpt-5"):
|
||||
model_args["text"] = {
|
||||
"verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
|
||||
}
|
||||
|
||||
tools: list[ToolParam] = []
|
||||
if chat_log.llm_api:
|
||||
tools = [
|
||||
@@ -381,36 +493,11 @@ class OpenAIBaseLLMEntity(Entity):
|
||||
),
|
||||
)
|
||||
)
|
||||
model_args.setdefault("include", []).append("code_interpreter_call.outputs") # type: ignore[union-attr]
|
||||
|
||||
messages = _convert_content_to_param(chat_log.content)
|
||||
|
||||
model_args = ResponseCreateParamsStreaming(
|
||||
model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
|
||||
input=messages,
|
||||
max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
|
||||
temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
|
||||
user=chat_log.conversation_id,
|
||||
store=False,
|
||||
stream=True,
|
||||
)
|
||||
if tools:
|
||||
model_args["tools"] = tools
|
||||
|
||||
if model_args["model"].startswith(("o", "gpt-5")):
|
||||
model_args["reasoning"] = {
|
||||
"effort": options.get(
|
||||
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
|
||||
),
|
||||
"summary": "auto",
|
||||
}
|
||||
model_args["include"] = ["reasoning.encrypted_content"]
|
||||
|
||||
if model_args["model"].startswith("gpt-5"):
|
||||
model_args["text"] = {
|
||||
"verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
|
||||
}
|
||||
|
||||
last_content = chat_log.content[-1]
|
||||
|
||||
# Handle attachments by adding them to the last user message
|
||||
|
@@ -29,6 +29,7 @@ from openai.types.responses import (
|
||||
ResponseWebSearchCallInProgressEvent,
|
||||
ResponseWebSearchCallSearchingEvent,
|
||||
)
|
||||
from openai.types.responses.response_code_interpreter_tool_call import OutputLogs
|
||||
from openai.types.responses.response_function_web_search import ActionSearch
|
||||
from openai.types.responses.response_reasoning_item import Summary
|
||||
|
||||
@@ -320,7 +321,7 @@ def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEve
|
||||
|
||||
|
||||
def create_code_interpreter_item(
|
||||
id: str, code: str | list[str], output_index: int
|
||||
id: str, code: str | list[str], output_index: int, logs: str | None = None
|
||||
) -> list[ResponseStreamEvent]:
|
||||
"""Create a message item."""
|
||||
if isinstance(code, str):
|
||||
@@ -388,7 +389,7 @@ def create_code_interpreter_item(
|
||||
id=id,
|
||||
code=code,
|
||||
container_id=container_id,
|
||||
outputs=None,
|
||||
outputs=[OutputLogs(type="logs", logs=logs)] if logs else None,
|
||||
status="completed",
|
||||
type="code_interpreter_call",
|
||||
),
|
||||
|
@@ -1,4 +1,39 @@
|
||||
# serializer version: 1
|
||||
# name: test_code_interpreter
|
||||
list([
|
||||
dict({
|
||||
'content': 'Please use the python tool to calculate square root of 55555',
|
||||
'role': 'user',
|
||||
'type': 'message',
|
||||
}),
|
||||
dict({
|
||||
'arguments': '{"code": "import math\\nmath.sqrt(55555)", "container": "cntr_A"}',
|
||||
'call_id': 'ci_A',
|
||||
'name': 'code_interpreter',
|
||||
'type': 'function_call',
|
||||
}),
|
||||
dict({
|
||||
'call_id': 'ci_A',
|
||||
'output': '{"output": [{"logs": "235.70108188126758\\n", "type": "logs"}]}',
|
||||
'type': 'function_call_output',
|
||||
}),
|
||||
dict({
|
||||
'content': 'I’ve calculated it with Python: the square root of 55555 is approximately 235.70108188126758.',
|
||||
'role': 'assistant',
|
||||
'type': 'message',
|
||||
}),
|
||||
dict({
|
||||
'content': 'Thank you!',
|
||||
'role': 'user',
|
||||
'type': 'message',
|
||||
}),
|
||||
dict({
|
||||
'content': 'You are welcome!',
|
||||
'role': 'assistant',
|
||||
'type': 'message',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_function_call
|
||||
list([
|
||||
dict({
|
||||
@@ -172,3 +207,36 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_web_search
|
||||
list([
|
||||
dict({
|
||||
'content': "What's on the latest news?",
|
||||
'role': 'user',
|
||||
'type': 'message',
|
||||
}),
|
||||
dict({
|
||||
'action': dict({
|
||||
'query': 'query',
|
||||
'type': 'search',
|
||||
}),
|
||||
'id': 'ws_A',
|
||||
'status': 'completed',
|
||||
'type': 'web_search_call',
|
||||
}),
|
||||
dict({
|
||||
'content': 'Home Assistant now supports ChatGPT Search in Assist',
|
||||
'role': 'assistant',
|
||||
'type': 'message',
|
||||
}),
|
||||
dict({
|
||||
'content': 'Thank you!',
|
||||
'role': 'user',
|
||||
'type': 'message',
|
||||
}),
|
||||
dict({
|
||||
'content': 'You are welcome!',
|
||||
'role': 'assistant',
|
||||
'type': 'message',
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
|
@@ -435,6 +435,7 @@ async def test_web_search(
|
||||
mock_init_component,
|
||||
mock_create_stream,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test web_search_tool."""
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
@@ -487,6 +488,21 @@ async def test_web_search(
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert result.response.speech["plain"]["speech"] == message, result.response.speech
|
||||
|
||||
# Test follow-up message in multi-turn conversation
|
||||
mock_create_stream.return_value = [
|
||||
(*create_message_item(id="msg_B", text="You are welcome!", output_index=1),)
|
||||
]
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Thank you!",
|
||||
mock_chat_log.conversation_id,
|
||||
Context(),
|
||||
agent_id="conversation.openai_conversation",
|
||||
)
|
||||
|
||||
assert mock_create_stream.mock_calls[1][2]["input"][1:] == snapshot
|
||||
|
||||
|
||||
async def test_code_interpreter(
|
||||
hass: HomeAssistant,
|
||||
@@ -494,6 +510,7 @@ async def test_code_interpreter(
|
||||
mock_init_component,
|
||||
mock_create_stream,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test code_interpreter tool."""
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
@@ -513,6 +530,7 @@ async def test_code_interpreter(
|
||||
*create_code_interpreter_item(
|
||||
id="ci_A",
|
||||
code=["import", " math", "\n", "math", ".sqrt", "(", "555", "55", ")"],
|
||||
logs="235.70108188126758\n",
|
||||
output_index=0,
|
||||
),
|
||||
*create_message_item(id="msg_A", text=message, output_index=1),
|
||||
@@ -532,3 +550,18 @@ async def test_code_interpreter(
|
||||
]
|
||||
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
|
||||
assert result.response.speech["plain"]["speech"] == message, result.response.speech
|
||||
|
||||
# Test follow-up message in multi-turn conversation
|
||||
mock_create_stream.return_value = [
|
||||
(*create_message_item(id="msg_B", text="You are welcome!", output_index=1),)
|
||||
]
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"Thank you!",
|
||||
mock_chat_log.conversation_id,
|
||||
Context(),
|
||||
agent_id="conversation.openai_conversation",
|
||||
)
|
||||
|
||||
assert mock_create_stream.mock_calls[1][2]["input"][1:] == snapshot
|
||||
|
Reference in New Issue
Block a user