mirror of
https://github.com/home-assistant/core.git
synced 2025-09-07 22:01:34 +02:00
Add thinking and native content to chatlog (#149699)
This commit is contained in:
@@ -161,7 +161,9 @@ class AssistantContent:
|
||||
role: Literal["assistant"] = field(init=False, default="assistant")
|
||||
agent_id: str
|
||||
content: str | None = None
|
||||
thinking_content: str | None = None
|
||||
tool_calls: list[llm.ToolInput] | None = None
|
||||
native: Any = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -183,7 +185,9 @@ class AssistantContentDeltaDict(TypedDict, total=False):
|
||||
|
||||
role: Literal["assistant"]
|
||||
content: str | None
|
||||
thinking_content: str | None
|
||||
tool_calls: list[llm.ToolInput] | None
|
||||
native: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -306,6 +310,8 @@ class ChatLog:
|
||||
The keys content and tool_calls will be concatenated if they appear multiple times.
|
||||
"""
|
||||
current_content = ""
|
||||
current_thinking_content = ""
|
||||
current_native: Any = None
|
||||
current_tool_calls: list[llm.ToolInput] = []
|
||||
tool_call_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
@@ -316,6 +322,14 @@ class ChatLog:
|
||||
if "role" not in delta:
|
||||
if delta_content := delta.get("content"):
|
||||
current_content += delta_content
|
||||
if delta_thinking_content := delta.get("thinking_content"):
|
||||
current_thinking_content += delta_thinking_content
|
||||
if delta_native := delta.get("native"):
|
||||
if current_native is not None:
|
||||
raise RuntimeError(
|
||||
"Native content already set, cannot overwrite"
|
||||
)
|
||||
current_native = delta_native
|
||||
if delta_tool_calls := delta.get("tool_calls"):
|
||||
if self.llm_api is None:
|
||||
raise ValueError("No LLM API configured")
|
||||
@@ -337,11 +351,18 @@ class ChatLog:
|
||||
raise ValueError(f"Only assistant role expected. Got {delta['role']}")
|
||||
|
||||
# Yield the previous message if it has content
|
||||
if current_content or current_tool_calls:
|
||||
if (
|
||||
current_content
|
||||
or current_thinking_content
|
||||
or current_tool_calls
|
||||
or current_native
|
||||
):
|
||||
content = AssistantContent(
|
||||
agent_id=agent_id,
|
||||
content=current_content or None,
|
||||
thinking_content=current_thinking_content or None,
|
||||
tool_calls=current_tool_calls or None,
|
||||
native=current_native,
|
||||
)
|
||||
yield content
|
||||
async for tool_result in self.async_add_assistant_content(
|
||||
@@ -352,16 +373,25 @@ class ChatLog:
|
||||
self.delta_listener(self, asdict(tool_result))
|
||||
|
||||
current_content = delta.get("content") or ""
|
||||
current_thinking_content = delta.get("thinking_content") or ""
|
||||
current_tool_calls = delta.get("tool_calls") or []
|
||||
current_native = delta.get("native")
|
||||
|
||||
if self.delta_listener:
|
||||
self.delta_listener(self, delta) # type: ignore[arg-type]
|
||||
|
||||
if current_content or current_tool_calls:
|
||||
if (
|
||||
current_content
|
||||
or current_thinking_content
|
||||
or current_tool_calls
|
||||
or current_native
|
||||
):
|
||||
content = AssistantContent(
|
||||
agent_id=agent_id,
|
||||
content=current_content or None,
|
||||
thinking_content=current_thinking_content or None,
|
||||
tool_calls=current_tool_calls or None,
|
||||
native=current_native,
|
||||
)
|
||||
yield content
|
||||
async for tool_result in self.async_add_assistant_content(
|
||||
|
@@ -16,7 +16,9 @@
|
||||
dict({
|
||||
'agent_id': 'ai_task.test_task_entity',
|
||||
'content': 'Mock result',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
|
@@ -19,7 +19,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': 'Certainly, calling it now!',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'toolu_0123456789AbCdEfGhIjKlM',
|
||||
@@ -40,7 +42,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': 'I have successfully called the function',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
|
@@ -3,12 +3,27 @@
|
||||
list([
|
||||
])
|
||||
# ---
|
||||
# name: test_add_delta_content_stream[deltas10]
|
||||
list([
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': None,
|
||||
'native': object(
|
||||
),
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_add_delta_content_stream[deltas1]
|
||||
list([
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': 'Test',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
@@ -18,13 +33,17 @@
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': 'Test',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': 'Test 2',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
@@ -34,7 +53,9 @@
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': None,
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'mock-tool-call-id',
|
||||
@@ -59,7 +80,9 @@
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': 'Test',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'mock-tool-call-id',
|
||||
@@ -84,7 +107,9 @@
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': 'Test',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'mock-tool-call-id',
|
||||
@@ -105,7 +130,9 @@
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': 'Test 2',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
@@ -115,7 +142,9 @@
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': None,
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'mock-tool-call-id',
|
||||
@@ -149,6 +178,45 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_add_delta_content_stream[deltas7]
|
||||
list([
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': None,
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': 'Test Thinking',
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_add_delta_content_stream[deltas8]
|
||||
list([
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': 'Test',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': 'Test Thinking',
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_add_delta_content_stream[deltas9]
|
||||
list([
|
||||
dict({
|
||||
'agent_id': 'mock-agent-id',
|
||||
'content': None,
|
||||
'native': dict({
|
||||
'type': 'test',
|
||||
'value': 'Test Native',
|
||||
}),
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_template_error
|
||||
dict({
|
||||
'continue_conversation': False,
|
||||
|
@@ -517,6 +517,27 @@ async def test_tool_call_exception(
|
||||
]
|
||||
},
|
||||
],
|
||||
# With thinking content
|
||||
[
|
||||
{"role": "assistant"},
|
||||
{"thinking_content": "Test Thinking"},
|
||||
],
|
||||
# With content and thinking content
|
||||
[
|
||||
{"role": "assistant"},
|
||||
{"content": "Test"},
|
||||
{"thinking_content": "Test Thinking"},
|
||||
],
|
||||
# With native content
|
||||
[
|
||||
{"role": "assistant"},
|
||||
{"native": {"type": "test", "value": "Test Native"}},
|
||||
],
|
||||
# With native object content
|
||||
[
|
||||
{"role": "assistant"},
|
||||
{"native": object()},
|
||||
],
|
||||
],
|
||||
)
|
||||
async def test_add_delta_content_stream(
|
||||
@@ -634,6 +655,20 @@ async def test_add_delta_content_stream_errors(
|
||||
):
|
||||
pass
|
||||
|
||||
# Second native content
|
||||
with pytest.raises(RuntimeError):
|
||||
async for _tool_result_content in chat_log.async_add_delta_content_stream(
|
||||
"mock-agent-id",
|
||||
stream(
|
||||
[
|
||||
{"role": "assistant"},
|
||||
{"native": "Test Native"},
|
||||
{"native": "Test Native 2"},
|
||||
]
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
async def test_chat_log_reuse(
|
||||
hass: HomeAssistant,
|
||||
|
@@ -113,7 +113,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||
'content': 'Hello, how can I help you?',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
@@ -128,7 +130,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||
'content': None,
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'call_call_1',
|
||||
@@ -149,7 +153,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.gpt_3_5_turbo',
|
||||
'content': 'I have successfully called the function',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
|
@@ -9,7 +9,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.openai_conversation',
|
||||
'content': None,
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'call_call_1',
|
||||
@@ -30,7 +32,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.openai_conversation',
|
||||
'content': None,
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'call_call_2',
|
||||
@@ -51,7 +55,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.openai_conversation',
|
||||
'content': 'Cool',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
@@ -66,7 +72,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.openai_conversation',
|
||||
'content': None,
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'call_call_1',
|
||||
@@ -87,7 +95,9 @@
|
||||
dict({
|
||||
'agent_id': 'conversation.openai_conversation',
|
||||
'content': 'Cool',
|
||||
'native': None,
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
|
Reference in New Issue
Block a user