diff --git a/homeassistant/components/openai_conversation/entity.py b/homeassistant/components/openai_conversation/entity.py
index 748c0c8f874..9c1e77be7d3 100644
--- a/homeassistant/components/openai_conversation/entity.py
+++ b/homeassistant/components/openai_conversation/entity.py
@@ -3,11 +3,11 @@ from __future__ import annotations
 
 import base64
-from collections.abc import AsyncGenerator, Callable
+from collections.abc import AsyncGenerator, Callable, Iterable
 import json
 from mimetypes import guess_file_type
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Literal
 
 import openai
 from openai._streaming import AsyncStream
@@ -29,14 +29,15 @@ from openai.types.responses import (
     ResponseOutputItemAddedEvent,
     ResponseOutputItemDoneEvent,
     ResponseOutputMessage,
-    ResponseOutputMessageParam,
     ResponseReasoningItem,
     ResponseReasoningItemParam,
+    ResponseReasoningSummaryTextDeltaEvent,
     ResponseStreamEvent,
     ResponseTextDeltaEvent,
     ToolParam,
     WebSearchToolParam,
 )
+from openai.types.responses.response_create_params import ResponseCreateParamsStreaming
 from openai.types.responses.response_input_param import FunctionCallOutput
 from openai.types.responses.tool_param import (
     CodeInterpreter,
@@ -143,70 +144,116 @@ def _format_tool(
 
 
 def _convert_content_to_param(
-    content: conversation.Content,
+    chat_content: Iterable[conversation.Content],
 ) -> ResponseInputParam:
     """Convert any native chat message for this agent to the native format."""
     messages: ResponseInputParam = []
-    if isinstance(content, conversation.ToolResultContent):
-        return [
-            FunctionCallOutput(
-                type="function_call_output",
-                call_id=content.tool_call_id,
-                output=json.dumps(content.tool_result),
-            )
-        ]
+    reasoning_summary: list[str] = []
 
-    if content.content:
-        role: Literal["user", "assistant", "system", "developer"] = content.role
-        if role == "system":
-            role = "developer"
-        messages.append(
-            EasyInputMessageParam(type="message", role=role, content=content.content)
-        )
-
-    if isinstance(content, conversation.AssistantContent) and content.tool_calls:
-        messages.extend(
-            ResponseFunctionToolCallParam(
-                type="function_call",
-                name=tool_call.tool_name,
-                arguments=json.dumps(tool_call.tool_args),
-                call_id=tool_call.id,
+    for content in chat_content:
+        if isinstance(content, conversation.ToolResultContent):
+            messages.append(
+                FunctionCallOutput(
+                    type="function_call_output",
+                    call_id=content.tool_call_id,
+                    output=json.dumps(content.tool_result),
+                )
             )
-            for tool_call in content.tool_calls
-        )
+            continue
+
+        if content.content:
+            role: Literal["user", "assistant", "system", "developer"] = content.role
+            if role == "system":
+                role = "developer"
+            messages.append(
+                EasyInputMessageParam(
+                    type="message", role=role, content=content.content
+                )
+            )
+
+        if isinstance(content, conversation.AssistantContent):
+            if content.tool_calls:
+                messages.extend(
+                    ResponseFunctionToolCallParam(
+                        type="function_call",
+                        name=tool_call.tool_name,
+                        arguments=json.dumps(tool_call.tool_args),
+                        call_id=tool_call.id,
+                    )
+                    for tool_call in content.tool_calls
+                )
+
+            if content.thinking_content:
+                reasoning_summary.append(content.thinking_content)
+
+            if isinstance(content.native, ResponseReasoningItem):
+                messages.append(
+                    ResponseReasoningItemParam(
+                        type="reasoning",
+                        id=content.native.id,
+                        summary=[
+                            {
+                                "type": "summary_text",
+                                "text": summary,
+                            }
+                            for summary in reasoning_summary
+                        ]
+                        if content.thinking_content
+                        else [],
+                        encrypted_content=content.native.encrypted_content,
+                    )
+                )
+                reasoning_summary = []
+
     return messages
 
 
 async def _transform_stream(
     chat_log: conversation.ChatLog,
-    result: AsyncStream[ResponseStreamEvent],
-    messages: ResponseInputParam,
+    stream: AsyncStream[ResponseStreamEvent],
 ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
     """Transform an OpenAI delta stream into HA format."""
-    async for event in result:
+    last_summary_index = None
+
+    async for event in stream:
         LOGGER.debug("Received event: %s", event)
 
         if isinstance(event, ResponseOutputItemAddedEvent):
             if isinstance(event.item, ResponseOutputMessage):
                 yield {"role": event.item.role}
+                last_summary_index = None
             elif isinstance(event.item, ResponseFunctionToolCall):
                 # OpenAI has tool calls as individual events
                 # while HA puts tool calls inside the assistant message.
                 # We turn them into individual assistant content for HA
                 # to ensure that tools are called as soon as possible.
                 yield {"role": "assistant"}
+                last_summary_index = None
                 current_tool_call = event.item
         elif isinstance(event, ResponseOutputItemDoneEvent):
-            item = event.item.model_dump()
-            item.pop("status", None)
             if isinstance(event.item, ResponseReasoningItem):
-                messages.append(cast(ResponseReasoningItemParam, item))
-            elif isinstance(event.item, ResponseOutputMessage):
-                messages.append(cast(ResponseOutputMessageParam, item))
-            elif isinstance(event.item, ResponseFunctionToolCall):
-                messages.append(cast(ResponseFunctionToolCallParam, item))
+                yield {
+                    "native": ResponseReasoningItem(
+                        type="reasoning",
+                        id=event.item.id,
+                        summary=[],  # Remove summaries
+                        encrypted_content=event.item.encrypted_content,
+                    )
+                }
         elif isinstance(event, ResponseTextDeltaEvent):
             yield {"content": event.delta}
+        elif isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
+            # OpenAI can output several reasoning summaries
+            # in a single ResponseReasoningItem. We split them as separate
+            # AssistantContent messages. Only last of them will have
+            # the reasoning `native` field set.
+            if (
+                last_summary_index is not None
+                and event.summary_index != last_summary_index
+            ):
+                yield {"role": "assistant"}
+            last_summary_index = event.summary_index
+            yield {"thinking_content": event.delta}
         elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent):
             current_tool_call.arguments += event.delta
         elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent):
@@ -335,16 +382,18 @@ class OpenAIBaseLLMEntity(Entity):
                 )
             )
 
-        model_args = {
-            "model": options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
-            "input": [],
-            "max_output_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
-            "top_p": options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
-            "temperature": options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
-            "user": chat_log.conversation_id,
-            "store": False,
-            "stream": True,
-        }
+        messages = _convert_content_to_param(chat_log.content)
+
+        model_args = ResponseCreateParamsStreaming(
+            model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
+            input=messages,
+            max_output_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+            top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+            temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
+            user=chat_log.conversation_id,
+            store=False,
+            stream=True,
+        )
 
         if tools:
             model_args["tools"] = tools
@@ -352,7 +401,8 @@
             model_args["reasoning"] = {
                 "effort": options.get(
                     CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
-                )
+                ),
+                "summary": "auto",
             }
             model_args["include"] = ["reasoning.encrypted_content"]
@@ -361,12 +411,6 @@
                 "verbosity": options.get(CONF_VERBOSITY, RECOMMENDED_VERBOSITY)
             }
 
-        messages = [
-            m
-            for content in chat_log.content
-            for m in _convert_content_to_param(content)
-        ]
-
         last_content = chat_log.content[-1]
 
         # Handle attachments by adding them to the last user message
@@ -399,16 +443,19 @@
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
-            model_args["input"] = messages
-
             try:
-                result = await client.responses.create(**model_args)
+                stream = await client.responses.create(**model_args)
 
-                async for content in chat_log.async_add_delta_content_stream(
-                    self.entity_id, _transform_stream(chat_log, result, messages)
-                ):
-                    if not isinstance(content, conversation.AssistantContent):
-                        messages.extend(_convert_content_to_param(content))
+                messages.extend(
+                    _convert_content_to_param(
+                        [
+                            content
+                            async for content in chat_log.async_add_delta_content_stream(
+                                self.entity_id, _transform_stream(chat_log, stream)
+                            )
+                        ]
+                    )
+                )
             except openai.RateLimitError as err:
                 LOGGER.error("Rate limited by OpenAI: %s", err)
                 raise HomeAssistantError("Rate limited or insufficient funds") from err
diff --git a/tests/components/openai_conversation/__init__.py b/tests/components/openai_conversation/__init__.py
index 0cdccb6d0cf..0ca02b8f629 100644
--- a/tests/components/openai_conversation/__init__.py
+++ b/tests/components/openai_conversation/__init__.py
@@ -18,6 +18,10 @@ from openai.types.responses import (
     ResponseOutputMessage,
     ResponseOutputText,
     ResponseReasoningItem,
+    ResponseReasoningSummaryPartAddedEvent,
+    ResponseReasoningSummaryPartDoneEvent,
+    ResponseReasoningSummaryTextDeltaEvent,
+    ResponseReasoningSummaryTextDoneEvent,
     ResponseStreamEvent,
     ResponseTextDeltaEvent,
     ResponseTextDoneEvent,
@@ -26,6 +30,7 @@ from openai.types.responses import (
     ResponseWebSearchCallSearchingEvent,
 )
 from openai.types.responses.response_function_web_search import ActionSearch
+from openai.types.responses.response_reasoning_item import Summary
 
 
 def create_message_item(
@@ -173,9 +178,23 @@ def create_function_tool_call_item(
     return events
 
 
-def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
+def create_reasoning_item(
+    id: str,
+    output_index: int,
+    reasoning_summary: list[list[str]] | list[str] | str | None = None,
+) -> list[ResponseStreamEvent]:
     """Create a reasoning item."""
-    return [
+
+    if reasoning_summary is None:
+        reasoning_summary = [[]]
+    elif isinstance(reasoning_summary, str):
+        reasoning_summary = [reasoning_summary]
+    if isinstance(reasoning_summary, list) and all(
+        isinstance(item, str) for item in reasoning_summary
+    ):
+        reasoning_summary = [reasoning_summary]
+
+    events = [
         ResponseOutputItemAddedEvent(
             item=ResponseReasoningItem(
                 id=id,
@@ -187,11 +206,60 @@ def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEven
             output_index=output_index,
             sequence_number=0,
             type="response.output_item.added",
-        ),
+        )
+    ]
+
+    for summary_index, summary in enumerate(reasoning_summary):
+        events.append(
+            ResponseReasoningSummaryPartAddedEvent(
+                item_id=id,
+                output_index=output_index,
+                part={"text": "", "type": "summary_text"},
+                sequence_number=0,
+                summary_index=summary_index,
+                type="response.reasoning_summary_part.added",
+            )
+        )
+        events.extend(
+            ResponseReasoningSummaryTextDeltaEvent(
+                delta=delta,
+                item_id=id,
+                output_index=output_index,
+                sequence_number=0,
+                summary_index=summary_index,
+                type="response.reasoning_summary_text.delta",
+            )
+            for delta in summary
+        )
+        events.extend(
+            [
+                ResponseReasoningSummaryTextDoneEvent(
+                    item_id=id,
+                    output_index=output_index,
+                    sequence_number=0,
+                    summary_index=summary_index,
+                    text="".join(summary),
+                    type="response.reasoning_summary_text.done",
+                ),
+                ResponseReasoningSummaryPartDoneEvent(
+                    item_id=id,
+                    output_index=output_index,
+                    part={"text": "".join(summary), "type": "summary_text"},
+                    sequence_number=0,
+                    summary_index=summary_index,
+                    type="response.reasoning_summary_part.done",
+                ),
+            ]
+        )
+
+    events.append(
         ResponseOutputItemDoneEvent(
             item=ResponseReasoningItem(
                 id=id,
-                summary=[],
+                summary=[
+                    Summary(text="".join(summary), type="summary_text")
+                    for summary in reasoning_summary
+                ],
                 type="reasoning",
                 status=None,
                 encrypted_content="AAABBB",
@@ -200,7 +268,9 @@ def create_reasoning_item(id: str, output_index: int) -> list[ResponseStreamEven
             sequence_number=0,
             type="response.output_item.done",
         ),
-    ]
+    )
+
+    return events
 
 
 def create_web_search_item(id: str, output_index: int) -> list[ResponseStreamEvent]:
diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr
index 93b86bd4bc1..7a03c484182 100644
--- a/tests/components/openai_conversation/snapshots/test_conversation.ambr
+++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr
@@ -6,6 +6,22 @@
       'content': 'Please call the test function',
       'role': 'user',
     }),
+    dict({
+      'agent_id': 'conversation.openai_conversation',
+      'content': None,
+      'native': None,
+      'role': 'assistant',
+      'thinking_content': 'Thinking',
+      'tool_calls': None,
+    }),
+    dict({
+      'agent_id': 'conversation.openai_conversation',
+      'content': None,
+      'native': ResponseReasoningItem(id='rs_A', summary=[], type='reasoning', content=None, encrypted_content='AAABBB', status=None),
+      'role': 'assistant',
+      'thinking_content': 'Thinking more',
+      'tool_calls': None,
+    }),
     dict({
       'agent_id': 'conversation.openai_conversation',
       'content': None,
@@ -62,6 +78,57 @@
     }),
   ])
 # ---
+# name: test_function_call.1
+  list([
+    dict({
+      'content': 'Please call the test function',
+      'role': 'user',
+      'type': 'message',
+    }),
+    dict({
+      'encrypted_content': 'AAABBB',
+      'id': 'rs_A',
+      'summary': list([
+        dict({
+          'text': 'Thinking',
+          'type': 'summary_text',
+        }),
+        dict({
+          'text': 'Thinking more',
+          'type': 'summary_text',
+        }),
+      ]),
+      'type': 'reasoning',
+    }),
+    dict({
+      'arguments': '{"param1": "call1"}',
+      'call_id': 'call_call_1',
+      'name': 'test_tool',
+      'type': 'function_call',
+    }),
+    dict({
+      'call_id': 'call_call_1',
+      'output': '"value1"',
+      'type': 'function_call_output',
+    }),
+    dict({
+      'arguments': '{"param1": "call2"}',
+      'call_id': 'call_call_2',
+      'name': 'test_tool',
+      'type': 'function_call',
+    }),
+    dict({
+      'call_id': 'call_call_2',
+      'output': '"value2"',
+      'type': 'function_call_output',
+    }),
+    dict({
+      'content': 'Cool',
+      'role': 'assistant',
+      'type': 'message',
+    }),
+  ])
+# ---
 # name: test_function_call_without_reasoning
   list([
     dict({
diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py
index 5abce689855..921eb39c542 100644
--- a/tests/components/openai_conversation/test_conversation.py
+++ b/tests/components/openai_conversation/test_conversation.py
@@ -252,7 +252,11 @@ async def test_function_call(
         # Initial conversation
         (
             # Wait for the model to think
-            *create_reasoning_item(id="rs_A", output_index=0),
+            *create_reasoning_item(
+                id="rs_A",
+                output_index=0,
+                reasoning_summary=[["Thinking"], ["Thinking ", "more"]],
+            ),
             # First tool call
             *create_function_tool_call_item(
                 id="fc_1",
@@ -288,16 +292,10 @@ async def test_function_call(
         agent_id="conversation.openai_conversation",
     )
 
-    assert mock_create_stream.call_args.kwargs["input"][2] == {
-        "content": None,
-        "id": "rs_A",
-        "summary": [],
-        "type": "reasoning",
-        "encrypted_content": "AAABBB",
-    }
     assert result.response.response_type == intent.IntentResponseType.ACTION_DONE
     # Don't test the prompt, as it's not deterministic
    assert mock_chat_log.content[1:] == snapshot
+    assert mock_create_stream.call_args.kwargs["input"][1:] == snapshot
 
 
 async def test_function_call_without_reasoning(
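
Review note: the subtle part of this change is the interplay between `_transform_stream`, which splits OpenAI reasoning summaries into separate `AssistantContent` deltas whenever `summary_index` changes, and `_convert_content_to_param`, which later re-joins the accumulated `thinking_content` strings into a single reasoning item's `summary` list for the next request. The snippet below is a minimal, self-contained sketch of the splitting rule only; the plain dicts standing in for the OpenAI event classes and the `split_summaries` helper are illustrative assumptions, not part of the patch or of the component's API.

# Illustrative sketch only: plain dicts stand in for
# ResponseReasoningSummaryTextDeltaEvent, and split_summaries mimics the
# summary_index handling added to _transform_stream above.

def split_summaries(events: list[dict]) -> list[dict]:
    """Collect thinking_content deltas, opening a new assistant message
    whenever the summary_index changes within one reasoning item."""
    out: list[dict] = []
    last_summary_index = None
    for event in events:
        if event["type"] != "response.reasoning_summary_text.delta":
            continue
        # A changed summary_index means OpenAI started another summary
        # inside the same reasoning item, so a new message is opened.
        if (
            last_summary_index is not None
            and event["summary_index"] != last_summary_index
        ):
            out.append({"role": "assistant"})
        last_summary_index = event["summary_index"]
        out.append({"thinking_content": event["delta"]})
    return out


if __name__ == "__main__":
    deltas = [
        {"type": "response.reasoning_summary_text.delta", "summary_index": 0, "delta": "Thinking"},
        {"type": "response.reasoning_summary_text.delta", "summary_index": 1, "delta": "Thinking "},
        {"type": "response.reasoning_summary_text.delta", "summary_index": 1, "delta": "more"},
    ]
    # -> [{'thinking_content': 'Thinking'}, {'role': 'assistant'},
    #     {'thinking_content': 'Thinking '}, {'thinking_content': 'more'}]
    print(split_summaries(deltas))

Run against the deltas used in test_function_call (reasoning_summary=[["Thinking"], ["Thinking ", "more"]]), this produces one implicit first message plus a single {"role": "assistant"} split, matching the two assistant entries with 'Thinking' and 'Thinking more' in the snapshot; on the next request, _convert_content_to_param folds both back into one reasoning param, which is what the new test_function_call.1 snapshot asserts.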