mirror of
https://github.com/home-assistant/core.git
synced 2025-09-08 14:21:33 +02:00
Add thinking and native content to chatlog (#149699)
This commit is contained in:
@@ -161,7 +161,9 @@ class AssistantContent:
|
|||||||
role: Literal["assistant"] = field(init=False, default="assistant")
|
role: Literal["assistant"] = field(init=False, default="assistant")
|
||||||
agent_id: str
|
agent_id: str
|
||||||
content: str | None = None
|
content: str | None = None
|
||||||
|
thinking_content: str | None = None
|
||||||
tool_calls: list[llm.ToolInput] | None = None
|
tool_calls: list[llm.ToolInput] | None = None
|
||||||
|
native: Any = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -183,7 +185,9 @@ class AssistantContentDeltaDict(TypedDict, total=False):
|
|||||||
|
|
||||||
role: Literal["assistant"]
|
role: Literal["assistant"]
|
||||||
content: str | None
|
content: str | None
|
||||||
|
thinking_content: str | None
|
||||||
tool_calls: list[llm.ToolInput] | None
|
tool_calls: list[llm.ToolInput] | None
|
||||||
|
native: Any
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -306,6 +310,8 @@ class ChatLog:
|
|||||||
The keys content and tool_calls will be concatenated if they appear multiple times.
|
The keys content and tool_calls will be concatenated if they appear multiple times.
|
||||||
"""
|
"""
|
||||||
current_content = ""
|
current_content = ""
|
||||||
|
current_thinking_content = ""
|
||||||
|
current_native: Any = None
|
||||||
current_tool_calls: list[llm.ToolInput] = []
|
current_tool_calls: list[llm.ToolInput] = []
|
||||||
tool_call_tasks: dict[str, asyncio.Task] = {}
|
tool_call_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
@@ -316,6 +322,14 @@ class ChatLog:
|
|||||||
if "role" not in delta:
|
if "role" not in delta:
|
||||||
if delta_content := delta.get("content"):
|
if delta_content := delta.get("content"):
|
||||||
current_content += delta_content
|
current_content += delta_content
|
||||||
|
if delta_thinking_content := delta.get("thinking_content"):
|
||||||
|
current_thinking_content += delta_thinking_content
|
||||||
|
if delta_native := delta.get("native"):
|
||||||
|
if current_native is not None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Native content already set, cannot overwrite"
|
||||||
|
)
|
||||||
|
current_native = delta_native
|
||||||
if delta_tool_calls := delta.get("tool_calls"):
|
if delta_tool_calls := delta.get("tool_calls"):
|
||||||
if self.llm_api is None:
|
if self.llm_api is None:
|
||||||
raise ValueError("No LLM API configured")
|
raise ValueError("No LLM API configured")
|
||||||
@@ -337,11 +351,18 @@ class ChatLog:
|
|||||||
raise ValueError(f"Only assistant role expected. Got {delta['role']}")
|
raise ValueError(f"Only assistant role expected. Got {delta['role']}")
|
||||||
|
|
||||||
# Yield the previous message if it has content
|
# Yield the previous message if it has content
|
||||||
if current_content or current_tool_calls:
|
if (
|
||||||
|
current_content
|
||||||
|
or current_thinking_content
|
||||||
|
or current_tool_calls
|
||||||
|
or current_native
|
||||||
|
):
|
||||||
content = AssistantContent(
|
content = AssistantContent(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
content=current_content or None,
|
content=current_content or None,
|
||||||
|
thinking_content=current_thinking_content or None,
|
||||||
tool_calls=current_tool_calls or None,
|
tool_calls=current_tool_calls or None,
|
||||||
|
native=current_native,
|
||||||
)
|
)
|
||||||
yield content
|
yield content
|
||||||
async for tool_result in self.async_add_assistant_content(
|
async for tool_result in self.async_add_assistant_content(
|
||||||
@@ -352,16 +373,25 @@ class ChatLog:
|
|||||||
self.delta_listener(self, asdict(tool_result))
|
self.delta_listener(self, asdict(tool_result))
|
||||||
|
|
||||||
current_content = delta.get("content") or ""
|
current_content = delta.get("content") or ""
|
||||||
|
current_thinking_content = delta.get("thinking_content") or ""
|
||||||
current_tool_calls = delta.get("tool_calls") or []
|
current_tool_calls = delta.get("tool_calls") or []
|
||||||
|
current_native = delta.get("native")
|
||||||
|
|
||||||
if self.delta_listener:
|
if self.delta_listener:
|
||||||
self.delta_listener(self, delta) # type: ignore[arg-type]
|
self.delta_listener(self, delta) # type: ignore[arg-type]
|
||||||
|
|
||||||
if current_content or current_tool_calls:
|
if (
|
||||||
|
current_content
|
||||||
|
or current_thinking_content
|
||||||
|
or current_tool_calls
|
||||||
|
or current_native
|
||||||
|
):
|
||||||
content = AssistantContent(
|
content = AssistantContent(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
content=current_content or None,
|
content=current_content or None,
|
||||||
|
thinking_content=current_thinking_content or None,
|
||||||
tool_calls=current_tool_calls or None,
|
tool_calls=current_tool_calls or None,
|
||||||
|
native=current_native,
|
||||||
)
|
)
|
||||||
yield content
|
yield content
|
||||||
async for tool_result in self.async_add_assistant_content(
|
async for tool_result in self.async_add_assistant_content(
|
||||||
|
@@ -16,7 +16,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'ai_task.test_task_entity',
|
'agent_id': 'ai_task.test_task_entity',
|
||||||
'content': 'Mock result',
|
'content': 'Mock result',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
|
@@ -19,7 +19,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.claude_conversation',
|
'agent_id': 'conversation.claude_conversation',
|
||||||
'content': 'Certainly, calling it now!',
|
'content': 'Certainly, calling it now!',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'toolu_0123456789AbCdEfGhIjKlM',
|
'id': 'toolu_0123456789AbCdEfGhIjKlM',
|
||||||
@@ -40,7 +42,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.claude_conversation',
|
'agent_id': 'conversation.claude_conversation',
|
||||||
'content': 'I have successfully called the function',
|
'content': 'I have successfully called the function',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
|
@@ -3,12 +3,27 @@
|
|||||||
list([
|
list([
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas10]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': None,
|
||||||
|
'native': object(
|
||||||
|
),
|
||||||
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_add_delta_content_stream[deltas1]
|
# name: test_add_delta_content_stream[deltas1]
|
||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'agent_id': 'mock-agent-id',
|
'agent_id': 'mock-agent-id',
|
||||||
'content': 'Test',
|
'content': 'Test',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
@@ -18,13 +33,17 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'mock-agent-id',
|
'agent_id': 'mock-agent-id',
|
||||||
'content': 'Test',
|
'content': 'Test',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'agent_id': 'mock-agent-id',
|
'agent_id': 'mock-agent-id',
|
||||||
'content': 'Test 2',
|
'content': 'Test 2',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
@@ -34,7 +53,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'mock-agent-id',
|
'agent_id': 'mock-agent-id',
|
||||||
'content': None,
|
'content': None,
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'mock-tool-call-id',
|
'id': 'mock-tool-call-id',
|
||||||
@@ -59,7 +80,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'mock-agent-id',
|
'agent_id': 'mock-agent-id',
|
||||||
'content': 'Test',
|
'content': 'Test',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'mock-tool-call-id',
|
'id': 'mock-tool-call-id',
|
||||||
@@ -84,7 +107,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'mock-agent-id',
|
'agent_id': 'mock-agent-id',
|
||||||
'content': 'Test',
|
'content': 'Test',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'mock-tool-call-id',
|
'id': 'mock-tool-call-id',
|
||||||
@@ -105,7 +130,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'mock-agent-id',
|
'agent_id': 'mock-agent-id',
|
||||||
'content': 'Test 2',
|
'content': 'Test 2',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
@@ -115,7 +142,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'mock-agent-id',
|
'agent_id': 'mock-agent-id',
|
||||||
'content': None,
|
'content': None,
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'mock-tool-call-id',
|
'id': 'mock-tool-call-id',
|
||||||
@@ -149,6 +178,45 @@
|
|||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas7]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': None,
|
||||||
|
'native': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'thinking_content': 'Test Thinking',
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas8]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': 'Test',
|
||||||
|
'native': None,
|
||||||
|
'role': 'assistant',
|
||||||
|
'thinking_content': 'Test Thinking',
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_add_delta_content_stream[deltas9]
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'agent_id': 'mock-agent-id',
|
||||||
|
'content': None,
|
||||||
|
'native': dict({
|
||||||
|
'type': 'test',
|
||||||
|
'value': 'Test Native',
|
||||||
|
}),
|
||||||
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
|
'tool_calls': None,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
# name: test_template_error
|
# name: test_template_error
|
||||||
dict({
|
dict({
|
||||||
'continue_conversation': False,
|
'continue_conversation': False,
|
||||||
|
@@ -517,6 +517,27 @@ async def test_tool_call_exception(
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
# With thinking content
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"thinking_content": "Test Thinking"},
|
||||||
|
],
|
||||||
|
# With content and thinking content
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"content": "Test"},
|
||||||
|
{"thinking_content": "Test Thinking"},
|
||||||
|
],
|
||||||
|
# With native content
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"native": {"type": "test", "value": "Test Native"}},
|
||||||
|
],
|
||||||
|
# With native object content
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"native": object()},
|
||||||
|
],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_add_delta_content_stream(
|
async def test_add_delta_content_stream(
|
||||||
@@ -634,6 +655,20 @@ async def test_add_delta_content_stream_errors(
|
|||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Second native content
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
async for _tool_result_content in chat_log.async_add_delta_content_stream(
|
||||||
|
"mock-agent-id",
|
||||||
|
stream(
|
||||||
|
[
|
||||||
|
{"role": "assistant"},
|
||||||
|
{"native": "Test Native"},
|
||||||
|
{"native": "Test Native 2"},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def test_chat_log_reuse(
|
async def test_chat_log_reuse(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@@ -113,7 +113,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||||
'content': 'Hello, how can I help you?',
|
'content': 'Hello, how can I help you?',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
@@ -128,7 +130,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||||
'content': None,
|
'content': None,
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'call_call_1',
|
'id': 'call_call_1',
|
||||||
@@ -149,7 +153,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||||
'content': 'I have successfully called the function',
|
'content': 'I have successfully called the function',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
|
@@ -9,7 +9,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.openai_conversation',
|
'agent_id': 'conversation.openai_conversation',
|
||||||
'content': None,
|
'content': None,
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'call_call_1',
|
'id': 'call_call_1',
|
||||||
@@ -30,7 +32,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.openai_conversation',
|
'agent_id': 'conversation.openai_conversation',
|
||||||
'content': None,
|
'content': None,
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'call_call_2',
|
'id': 'call_call_2',
|
||||||
@@ -51,7 +55,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.openai_conversation',
|
'agent_id': 'conversation.openai_conversation',
|
||||||
'content': 'Cool',
|
'content': 'Cool',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
@@ -66,7 +72,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.openai_conversation',
|
'agent_id': 'conversation.openai_conversation',
|
||||||
'content': None,
|
'content': None,
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': list([
|
'tool_calls': list([
|
||||||
dict({
|
dict({
|
||||||
'id': 'call_call_1',
|
'id': 'call_call_1',
|
||||||
@@ -87,7 +95,9 @@
|
|||||||
dict({
|
dict({
|
||||||
'agent_id': 'conversation.openai_conversation',
|
'agent_id': 'conversation.openai_conversation',
|
||||||
'content': 'Cool',
|
'content': 'Cool',
|
||||||
|
'native': None,
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
|
'thinking_content': None,
|
||||||
'tool_calls': None,
|
'tool_calls': None,
|
||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
|
Reference in New Issue
Block a user