From 9797d391afacabba6cc11dfd92eaa609b5e8fad3 Mon Sep 17 00:00:00 2001
From: Denis Shulyaka
Date: Wed, 20 Aug 2025 00:36:23 +0300
Subject: [PATCH] OpenAI external tools (#150599)

---
 .../components/openai_conversation/entity.py  | 179 +++++++++++++-----
 .../openai_conversation/__init__.py           |   5 +-
 .../snapshots/test_conversation.ambr          |  68 +++++++
 .../openai_conversation/test_conversation.py  |  33 ++++
 4 files changed, 237 insertions(+), 48 deletions(-)

diff --git a/homeassistant/components/openai_conversation/entity.py b/homeassistant/components/openai_conversation/entity.py
index 64439c8a9c1..44d833c8e71 100644
--- a/homeassistant/components/openai_conversation/entity.py
+++ b/homeassistant/components/openai_conversation/entity.py
@@ -14,6 +14,7 @@ from openai._streaming import AsyncStream
 from openai.types.responses import (
     EasyInputMessageParam,
     FunctionToolParam,
+    ResponseCodeInterpreterToolCall,
     ResponseCompletedEvent,
     ResponseErrorEvent,
     ResponseFailedEvent,
@@ -21,6 +22,8 @@ from openai.types.responses import (
     ResponseFunctionCallArgumentsDoneEvent,
     ResponseFunctionToolCall,
     ResponseFunctionToolCallParam,
+    ResponseFunctionWebSearch,
+    ResponseFunctionWebSearchParam,
     ResponseIncompleteEvent,
     ResponseInputFileParam,
     ResponseInputImageParam,
@@ -149,16 +152,27 @@ def _convert_content_to_param(
     """Convert any native chat message for this agent to the native format."""
     messages: ResponseInputParam = []
     reasoning_summary: list[str] = []
+    web_search_calls: dict[str, ResponseFunctionWebSearchParam] = {}
 
     for content in chat_content:
         if isinstance(content, conversation.ToolResultContent):
-            messages.append(
-                FunctionCallOutput(
-                    type="function_call_output",
-                    call_id=content.tool_call_id,
-                    output=json.dumps(content.tool_result),
+            if (
+                content.tool_name == "web_search_call"
+                and content.tool_call_id in web_search_calls
+            ):
+                web_search_call = web_search_calls.pop(content.tool_call_id)
+                web_search_call["status"] = content.tool_result.get(  # type: ignore[typeddict-item]
+                    "status", "completed"
+                )
+                messages.append(web_search_call)
+            else:
+                messages.append(
+                    FunctionCallOutput(
+                        type="function_call_output",
+                        call_id=content.tool_call_id,
+                        output=json.dumps(content.tool_result),
+                    )
                 )
-            )
             continue
 
         if content.content:
@@ -173,15 +187,27 @@ def _convert_content_to_param(
 
         if isinstance(content, conversation.AssistantContent):
             if content.tool_calls:
-                messages.extend(
-                    ResponseFunctionToolCallParam(
-                        type="function_call",
-                        name=tool_call.tool_name,
-                        arguments=json.dumps(tool_call.tool_args),
-                        call_id=tool_call.id,
-                    )
-                    for tool_call in content.tool_calls
-                )
+                for tool_call in content.tool_calls:
+                    if (
+                        tool_call.external
+                        and tool_call.tool_name == "web_search_call"
+                        and "action" in tool_call.tool_args
+                    ):
+                        web_search_calls[tool_call.id] = ResponseFunctionWebSearchParam(
+                            type="web_search_call",
+                            id=tool_call.id,
+                            action=tool_call.tool_args["action"],
+                            status="completed",
+                        )
+                    else:
+                        messages.append(
+                            ResponseFunctionToolCallParam(
+                                type="function_call",
+                                name=tool_call.tool_name,
+                                arguments=json.dumps(tool_call.tool_args),
+                                call_id=tool_call.id,
+                            )
+                        )
 
             if content.thinking_content:
                 reasoning_summary.append(content.thinking_content)
@@ -211,25 +237,37 @@ def _convert_content_to_param(
 async def _transform_stream(
     chat_log: conversation.ChatLog,
     stream: AsyncStream[ResponseStreamEvent],
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
+) -> AsyncGenerator[
+    conversation.AssistantContentDeltaDict | conversation.ToolResultContentDeltaDict
+]:
     """Transform an OpenAI delta stream into HA format."""
     last_summary_index = None
+    last_role: Literal["assistant", "tool_result"] | None = None
 
     async for event in stream:
         LOGGER.debug("Received event: %s", event)
 
         if isinstance(event, ResponseOutputItemAddedEvent):
-            if isinstance(event.item, ResponseOutputMessage):
-                yield {"role": event.item.role}
-                last_summary_index = None
-            elif isinstance(event.item, ResponseFunctionToolCall):
+            if isinstance(event.item, ResponseFunctionToolCall):
                 # OpenAI has tool calls as individual events
                 # while HA puts tool calls inside the assistant message.
                 # We turn them into individual assistant content for HA
                 # to ensure that tools are called as soon as possible.
                 yield {"role": "assistant"}
+                last_role = "assistant"
                 last_summary_index = None
                 current_tool_call = event.item
+            elif (
+                isinstance(event.item, ResponseOutputMessage)
+                or (
+                    isinstance(event.item, ResponseReasoningItem)
+                    and last_summary_index is not None
+                )  # Subsequent ResponseReasoningItem
+                or last_role != "assistant"
+            ):
+                yield {"role": "assistant"}
+                last_role = "assistant"
+                last_summary_index = None
         elif isinstance(event, ResponseOutputItemDoneEvent):
             if isinstance(event.item, ResponseReasoningItem):
                 yield {
@@ -240,6 +278,52 @@ async def _transform_stream(
                         encrypted_content=event.item.encrypted_content,
                     )
                 }
+                last_summary_index = len(event.item.summary) - 1
+            elif isinstance(event.item, ResponseCodeInterpreterToolCall):
+                yield {
+                    "tool_calls": [
+                        llm.ToolInput(
+                            id=event.item.id,
+                            tool_name="code_interpreter",
+                            tool_args={
+                                "code": event.item.code,
+                                "container": event.item.container_id,
+                            },
+                            external=True,
+                        )
+                    ]
+                }
+                yield {
+                    "role": "tool_result",
+                    "tool_call_id": event.item.id,
+                    "tool_name": "code_interpreter",
+                    "tool_result": {
+                        "output": [output.to_dict() for output in event.item.outputs]  # type: ignore[misc]
+                        if event.item.outputs is not None
+                        else None
+                    },
+                }
+                last_role = "tool_result"
+            elif isinstance(event.item, ResponseFunctionWebSearch):
+                yield {
+                    "tool_calls": [
+                        llm.ToolInput(
+                            id=event.item.id,
+                            tool_name="web_search_call",
+                            tool_args={
+                                "action": event.item.action.to_dict(),
+                            },
+                            external=True,
+                        )
+                    ]
+                }
+                yield {
+                    "role": "tool_result",
+                    "tool_call_id": event.item.id,
+                    "tool_name": "web_search_call",
+                    "tool_result": {"status": event.item.status},
+                }
+                last_role = "tool_result"
         elif isinstance(event, ResponseTextDeltaEvent):
             yield {"content": event.delta}
         elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
@@ -252,6 +336,7 @@ async def _transform_stream(
                 and event.summary_index != last_summary_index
             ):
                 yield {"role": "assistant"}
+                last_role = "assistant"
                 last_summary_index = event.summary_index
             yield {"thinking_content": event.delta}
         elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
@@ -348,6 +433,33 @@ class OpenAIBaseLLMEntity(Entity):
         """Generate an answer for the chat log."""
         options = self.subentry.data
 
+        messages = _convert_content_to_param(chat_log.content)
+
+        model_args = ResponseCreateParamsStreaming(
+            model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
+            input=messages,
+            max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+            top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+            temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
+            user=chat_log.conversation_id,
+            store=False,
+            stream=True,
+        )
+
+        if model_args["model"].startswith(("o", "gpt-5")):
+            model_args["reasoning"] = {
+                "effort": options.get(
+                    CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
+                ),
+                "summary": "auto",
+            }
+            model_args["include"] = ["reasoning.encrypted_content"]
+
+        if model_args["model"].startswith("gpt-5"):
+            model_args["text"] = {
+                "verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
+            }
+
         tools: list[ToolParam] = []
         if chat_log.llm_api:
             tools = [
@@ -381,36 +493,11 @@ class OpenAIBaseLLMEntity(Entity):
                     ),
                 )
             )
+            model_args.setdefault("include", []).append("code_interpreter_call.outputs")  # type: ignore[union-attr]
 
-        messages = _convert_content_to_param(chat_log.content)
-
-        model_args = ResponseCreateParamsStreaming(
-            model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-            input=messages,
-            max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
-            top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-            temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-            user=chat_log.conversation_id,
-            store=False,
-            stream=True,
-        )
-
         if tools:
             model_args["tools"] = tools
 
-        if model_args["model"].startswith(("o", "gpt-5")):
-            model_args["reasoning"] = {
-                "effort": options.get(
-                    CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
-                ),
-                "summary": "auto",
-            }
-            model_args["include"] = ["reasoning.encrypted_content"]
-
-        if model_args["model"].startswith("gpt-5"):
-            model_args["text"] = {
-                "verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
-            }
-
         last_content = chat_log.content[-1]
 
         # Handle attachments by adding them to the last user message
diff --git a/tests/components/openai_conversation/__init__.py b/tests/components/openai_conversation/__init__.py
index 0ca02b8f629..e8effca3bc5 100644
--- a/tests/components/openai_conversation/__init__.py
+++ b/tests/components/openai_conversation/__init__.py
@@ -29,6 +29,7 @@ from openai.types.responses import (
     ResponseWebSearchCallInProgressEvent,
     ResponseWebSearchCallSearchingEvent,
 )
+from openai.types.responses.response_code_interpreter_tool_call import OutputLogs
 from openai.types.responses.response_function_web_search import ActionSearch
 from openai.types.responses.response_reasoning_item import Summary
 
@@ -320,7 +321,7 @@ def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEve
 
 
 def create_code_interpreter_item(
-    id: str, code: str | list[str], output_index: int
+    id: str, code: str | list[str], output_index: int, logs: str | None = None
 ) -> list[ResponseStreamEvent]:
     """Create a message item."""
     if isinstance(code, str):
@@ -388,7 +389,7 @@ def create_code_interpreter_item(
             id=id,
             code=code,
             container_id=container_id,
-            outputs=None,
+            outputs=[OutputLogs(type="logs", logs=logs)] if logs else None,
             status="completed",
             type="code_interpreter_call",
         ),
diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr
index d33d62214ef..473d32a53f8 100644
--- a/tests/components/openai_conversation/snapshots/test_conversation.ambr
+++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr
@@ -1,4 +1,39 @@
 # serializer version: 1
+# name: test_code_interpreter
+  list([
+    dict({
+      'content': 'Please use the python tool to calculate square root of 55555',
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'arguments': '{"code": "import math\\nmath.sqrt(55555)", "container": "cntr_A"}',
+      'call_id': 'ci_A',
+      'name': 'code_interpreter',
+      'type': 'function_call',
+    }),
+    dict({
+      'call_id': 'ci_A',
+      'output': '{"output": [{"logs": "235.70108188126758\\n", "type": "logs"}]}',
+      'type': 'function_call_output',
+    }),
+    dict({
+      'content': 'I’ve calculated it with Python: the square root of 55555 is approximately 235.70108188126758.',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+    dict({
+      'content': 'Thank you!',
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'content': 'You are welcome!',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+  ])
+# ---
 # name: test_function_call
   list([
     dict({
@@ -172,3 +207,36 @@
     }),
   ])
 # ---
+# name: test_web_search
+  list([
+    dict({
+      'content': "What's on the latest news?",
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'action': dict({
+        'query': 'query',
+        'type': 'search',
+      }),
+      'id': 'ws_A',
+      'status': 'completed',
+      'type': 'web_search_call',
+    }),
+    dict({
+      'content': 'Home Assistant now supports ChatGPT Search in Assist',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+    dict({
+      'content': 'Thank you!',
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'content': 'You are welcome!',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+  ])
+# ---
diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py
index 921eb39c542..452404f65ac 100644
--- a/tests/components/openai_conversation/test_conversation.py
+++ b/tests/components/openai_conversation/test_conversation.py
@@ -435,6 +435,7 @@ async def test_web_search(
     mock_init_component,
     mock_create_stream,
     mock_chat_log: MockChatLog,  # noqa: F811
+    snapshot: SnapshotAssertion,
 ) -> None:
     """Test web_search_tool."""
     subentry = next(iter(mock_config_entry.subentries.values()))
@@ -487,6 +488,21 @@ async def test_web_search(
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     assert result.response.speech["plain"]["speech"] == message, result.response.speech
 
+    # Test follow-up message in multi-turn conversation
+    mock_create_stream.return_value = [
+        (*create_message_item(id="msg_B", text="You are welcome!", output_index=1),)
+    ]
+
+    result = await conversation.async_converse(
+        hass,
+        "Thank you!",
+        mock_chat_log.conversation_id,
+        Context(),
+        agent_id="conversation.openai_conversation",
+    )
+
+    assert mock_create_stream.mock_calls[1][2]["input"][1:] == snapshot
+
 
 async def test_code_interpreter(
     hass: HomeAssistant,
@@ -494,6 +510,7 @@ async def test_code_interpreter(
     mock_init_component,
     mock_create_stream,
     mock_chat_log: MockChatLog,  # noqa: F811
+    snapshot: SnapshotAssertion,
 ) -> None:
     """Test code_interpreter tool."""
     subentry = next(iter(mock_config_entry.subentries.values()))
@@ -513,6 +530,7 @@ async def test_code_interpreter(
         *create_code_interpreter_item(
             id="ci_A",
             code=["import", " math", "\n", "math", ".sqrt", "(", "555", "55", ")"],
+            logs="235.70108188126758\n",
             output_index=0,
         ),
         *create_message_item(id="msg_A", text=message, output_index=1),
@@ -532,3 +550,18 @@ async def test_code_interpreter(
     ]
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     assert result.response.speech["plain"]["speech"] == message, result.response.speech
+
+    # Test follow-up message in multi-turn conversation
+    mock_create_stream.return_value = [
+        (*create_message_item(id="msg_B", text="You are welcome!", output_index=1),)
+    ]
+
+    result = await conversation.async_converse(
+        hass,
+        "Thank you!",
+        mock_chat_log.conversation_id,
+        Context(),
+        agent_id="conversation.openai_conversation",
+    )
+
+    assert mock_create_stream.mock_calls[1][2]["input"][1:] == snapshot
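
How the new external-tool round trip fits together (a standalone sketch, not
part of the patch): `_transform_stream` now yields an external
`web_search_call` as a ToolInput with `external=True`, immediately followed by
a `tool_result` delta; on the next turn, `_convert_content_to_param` stitches
that pair back into a single `web_search_call` input item instead of a generic
function_call/function_call_output pair. The plain dicts below are simplified
stand-ins for `llm.ToolInput` and `ResponseFunctionWebSearchParam`, and the
ids and values mirror the test fixtures above:

    # What the stream transformer emits for an external web search (simplified):
    tool_call = {
        "id": "ws_A",
        "tool_name": "web_search_call",
        "tool_args": {"action": {"type": "search", "query": "query"}},
        "external": True,
    }
    tool_result = {
        "tool_call_id": "ws_A",
        "tool_name": "web_search_call",
        "tool_result": {"status": "completed"},
    }

    # What the follow-up request's input is rebuilt into (simplified):
    web_search_calls = {
        tool_call["id"]: {
            "type": "web_search_call",
            "id": tool_call["id"],
            "action": tool_call["tool_args"]["action"],
            "status": "completed",
        }
    }
    call = web_search_calls.pop(tool_result["tool_call_id"])
    call["status"] = tool_result["tool_result"].get("status", "completed")
    assert call == {
        "type": "web_search_call",
        "id": "ws_A",
        "action": {"type": "search", "query": "query"},
        "status": "completed",
    }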
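
By contrast, `code_interpreter` gets no special case in
`_convert_content_to_param`, so on a follow-up turn the external call falls
through to the generic function_call/function_call_output path, which is what
the `test_code_interpreter` snapshot asserts. A runnable sketch of that replay,
with values taken from the snapshot (again, plain dicts standing in for the
typed params):

    import json

    tool_args = {"code": "import math\nmath.sqrt(55555)", "container": "cntr_A"}
    tool_result = {"output": [{"logs": "235.70108188126758\n", "type": "logs"}]}

    replayed_input = [
        {
            "type": "function_call",
            "name": "code_interpreter",
            "arguments": json.dumps(tool_args),
            "call_id": "ci_A",
        },
        {
            "type": "function_call_output",
            "call_id": "ci_A",
            "output": json.dumps(tool_result),
        },
    ]
    # json.dumps escapes the newline, matching the snapshot's "\\n" rendering.
    assert json.loads(replayed_input[1]["output"])["output"][0]["type"] == "logs"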