mirror of
https://github.com/home-assistant/core.git
synced 2025-08-31 02:11:32 +02:00
Anthropic thinking content (#150341)
This commit is contained in:
@@ -2,11 +2,10 @@
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Iterable
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
from anthropic import AsyncStream
|
||||
from anthropic._types import NOT_GIVEN
|
||||
from anthropic.types import (
|
||||
InputJSONDelta,
|
||||
MessageDeltaUsage,
|
||||
@@ -17,7 +16,6 @@ from anthropic.types import (
|
||||
RawContentBlockStopEvent,
|
||||
RawMessageDeltaEvent,
|
||||
RawMessageStartEvent,
|
||||
RawMessageStopEvent,
|
||||
RedactedThinkingBlock,
|
||||
RedactedThinkingBlockParam,
|
||||
SignatureDelta,
|
||||
@@ -35,6 +33,7 @@ from anthropic.types import (
|
||||
ToolUseBlockParam,
|
||||
Usage,
|
||||
)
|
||||
from anthropic.types.message_create_params import MessageCreateParamsStreaming
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import conversation
|
||||
@@ -129,6 +128,28 @@ def _convert_content(
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(content.native, ThinkingBlock):
|
||||
messages[-1]["content"].append( # type: ignore[union-attr]
|
||||
ThinkingBlockParam(
|
||||
type="thinking",
|
||||
thinking=content.thinking_content or "",
|
||||
signature=content.native.signature,
|
||||
)
|
||||
)
|
||||
elif isinstance(content.native, RedactedThinkingBlock):
|
||||
redacted_thinking_block = RedactedThinkingBlockParam(
|
||||
type="redacted_thinking",
|
||||
data=content.native.data,
|
||||
)
|
||||
if isinstance(messages[-1]["content"], str):
|
||||
messages[-1]["content"] = [
|
||||
TextBlockParam(type="text", text=messages[-1]["content"]),
|
||||
redacted_thinking_block,
|
||||
]
|
||||
else:
|
||||
messages[-1]["content"].append( # type: ignore[attr-defined]
|
||||
redacted_thinking_block
|
||||
)
|
||||
if content.content:
|
||||
messages[-1]["content"].append( # type: ignore[union-attr]
|
||||
TextBlockParam(type="text", text=content.content)
|
||||
@@ -152,10 +173,9 @@ def _convert_content(
|
||||
return messages
|
||||
|
||||
|
||||
async def _transform_stream( # noqa: C901 - This is complex, but better to have it in one place
|
||||
async def _transform_stream(
|
||||
chat_log: conversation.ChatLog,
|
||||
result: AsyncStream[MessageStreamEvent],
|
||||
messages: list[MessageParam],
|
||||
stream: AsyncStream[MessageStreamEvent],
|
||||
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
|
||||
"""Transform the response stream into HA format.
|
||||
|
||||
@@ -186,31 +206,25 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
|
||||
|
||||
Each message could contain multiple blocks of the same type.
|
||||
"""
|
||||
if result is None:
|
||||
if stream is None:
|
||||
raise TypeError("Expected a stream of messages")
|
||||
|
||||
current_message: MessageParam | None = None
|
||||
current_block: (
|
||||
TextBlockParam
|
||||
| ToolUseBlockParam
|
||||
| ThinkingBlockParam
|
||||
| RedactedThinkingBlockParam
|
||||
| None
|
||||
) = None
|
||||
current_tool_block: ToolUseBlockParam | None = None
|
||||
current_tool_args: str
|
||||
input_usage: Usage | None = None
|
||||
has_content = False
|
||||
has_native = False
|
||||
|
||||
async for response in result:
|
||||
async for response in stream:
|
||||
LOGGER.debug("Received response: %s", response)
|
||||
|
||||
if isinstance(response, RawMessageStartEvent):
|
||||
if response.message.role != "assistant":
|
||||
raise ValueError("Unexpected message role")
|
||||
current_message = MessageParam(role=response.message.role, content=[])
|
||||
input_usage = response.message.usage
|
||||
elif isinstance(response, RawContentBlockStartEvent):
|
||||
if isinstance(response.content_block, ToolUseBlock):
|
||||
current_block = ToolUseBlockParam(
|
||||
current_tool_block = ToolUseBlockParam(
|
||||
type="tool_use",
|
||||
id=response.content_block.id,
|
||||
name=response.content_block.name,
|
||||
@@ -218,75 +232,64 @@ async def _transform_stream( # noqa: C901 - This is complex, but better to have
|
||||
)
|
||||
current_tool_args = ""
|
||||
elif isinstance(response.content_block, TextBlock):
|
||||
current_block = TextBlockParam(
|
||||
type="text", text=response.content_block.text
|
||||
)
|
||||
yield {"role": "assistant"}
|
||||
if has_content:
|
||||
yield {"role": "assistant"}
|
||||
has_native = False
|
||||
has_content = True
|
||||
if response.content_block.text:
|
||||
yield {"content": response.content_block.text}
|
||||
elif isinstance(response.content_block, ThinkingBlock):
|
||||
current_block = ThinkingBlockParam(
|
||||
type="thinking",
|
||||
thinking=response.content_block.thinking,
|
||||
signature=response.content_block.signature,
|
||||
)
|
||||
if has_native:
|
||||
yield {"role": "assistant"}
|
||||
has_native = False
|
||||
has_content = False
|
||||
elif isinstance(response.content_block, RedactedThinkingBlock):
|
||||
current_block = RedactedThinkingBlockParam(
|
||||
type="redacted_thinking", data=response.content_block.data
|
||||
)
|
||||
LOGGER.debug(
|
||||
"Some of Claude’s internal reasoning has been automatically "
|
||||
"encrypted for safety reasons. This doesn’t affect the quality of "
|
||||
"responses"
|
||||
)
|
||||
if has_native:
|
||||
yield {"role": "assistant"}
|
||||
has_native = False
|
||||
has_content = False
|
||||
yield {"native": response.content_block}
|
||||
has_native = True
|
||||
elif isinstance(response, RawContentBlockDeltaEvent):
|
||||
if current_block is None:
|
||||
raise ValueError("Unexpected delta without a block")
|
||||
if isinstance(response.delta, InputJSONDelta):
|
||||
current_tool_args += response.delta.partial_json
|
||||
elif isinstance(response.delta, TextDelta):
|
||||
text_block = cast(TextBlockParam, current_block)
|
||||
text_block["text"] += response.delta.text
|
||||
yield {"content": response.delta.text}
|
||||
elif isinstance(response.delta, ThinkingDelta):
|
||||
thinking_block = cast(ThinkingBlockParam, current_block)
|
||||
thinking_block["thinking"] += response.delta.thinking
|
||||
yield {"thinking_content": response.delta.thinking}
|
||||
elif isinstance(response.delta, SignatureDelta):
|
||||
thinking_block = cast(ThinkingBlockParam, current_block)
|
||||
thinking_block["signature"] += response.delta.signature
|
||||
yield {
|
||||
"native": ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking="",
|
||||
signature=response.delta.signature,
|
||||
)
|
||||
}
|
||||
has_native = True
|
||||
elif isinstance(response, RawContentBlockStopEvent):
|
||||
if current_block is None:
|
||||
raise ValueError("Unexpected stop event without a current block")
|
||||
if current_block["type"] == "tool_use":
|
||||
# tool block
|
||||
if current_tool_block is not None:
|
||||
tool_args = json.loads(current_tool_args) if current_tool_args else {}
|
||||
current_block["input"] = tool_args
|
||||
current_tool_block["input"] = tool_args
|
||||
yield {
|
||||
"tool_calls": [
|
||||
llm.ToolInput(
|
||||
id=current_block["id"],
|
||||
tool_name=current_block["name"],
|
||||
id=current_tool_block["id"],
|
||||
tool_name=current_tool_block["name"],
|
||||
tool_args=tool_args,
|
||||
)
|
||||
]
|
||||
}
|
||||
elif current_block["type"] == "thinking":
|
||||
# thinking block
|
||||
LOGGER.debug("Thinking: %s", current_block["thinking"])
|
||||
|
||||
if current_message is None:
|
||||
raise ValueError("Unexpected stop event without a current message")
|
||||
current_message["content"].append(current_block) # type: ignore[union-attr]
|
||||
current_block = None
|
||||
current_tool_block = None
|
||||
elif isinstance(response, RawMessageDeltaEvent):
|
||||
if (usage := response.usage) is not None:
|
||||
chat_log.async_trace(_create_token_stats(input_usage, usage))
|
||||
if response.delta.stop_reason == "refusal":
|
||||
raise HomeAssistantError("Potential policy violation detected")
|
||||
elif isinstance(response, RawMessageStopEvent):
|
||||
if current_message is not None:
|
||||
messages.append(current_message)
|
||||
current_message = None
|
||||
|
||||
|
||||
def _create_token_stats(
|
||||
@@ -351,48 +354,48 @@ class AnthropicBaseLLMEntity(Entity):
|
||||
thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET)
|
||||
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
|
||||
|
||||
model_args = MessageCreateParamsStreaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
system=system.content,
|
||||
stream=True,
|
||||
)
|
||||
if tools:
|
||||
model_args["tools"] = tools
|
||||
if (
|
||||
model.startswith(tuple(THINKING_MODELS))
|
||||
and thinking_budget >= MIN_THINKING_BUDGET
|
||||
):
|
||||
model_args["thinking"] = ThinkingConfigEnabledParam(
|
||||
type="enabled", budget_tokens=thinking_budget
|
||||
)
|
||||
else:
|
||||
model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled")
|
||||
model_args["temperature"] = options.get(
|
||||
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
|
||||
)
|
||||
|
||||
# To prevent infinite loops, we limit the number of iterations
|
||||
for _iteration in range(MAX_TOOL_ITERATIONS):
|
||||
model_args = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"tools": tools or NOT_GIVEN,
|
||||
"max_tokens": options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
|
||||
"system": system.content,
|
||||
"stream": True,
|
||||
}
|
||||
if (
|
||||
model.startswith(tuple(THINKING_MODELS))
|
||||
and thinking_budget >= MIN_THINKING_BUDGET
|
||||
):
|
||||
model_args["thinking"] = ThinkingConfigEnabledParam(
|
||||
type="enabled", budget_tokens=thinking_budget
|
||||
)
|
||||
else:
|
||||
model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled")
|
||||
model_args["temperature"] = options.get(
|
||||
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
|
||||
)
|
||||
|
||||
try:
|
||||
stream = await client.messages.create(**model_args)
|
||||
|
||||
messages.extend(
|
||||
_convert_content(
|
||||
[
|
||||
content
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
self.entity_id,
|
||||
_transform_stream(chat_log, stream),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
except anthropic.AnthropicError as err:
|
||||
raise HomeAssistantError(
|
||||
f"Sorry, I had a problem talking to Anthropic: {err}"
|
||||
) from err
|
||||
|
||||
messages.extend(
|
||||
_convert_content(
|
||||
[
|
||||
content
|
||||
async for content in chat_log.async_add_delta_content_stream(
|
||||
self.entity_id,
|
||||
_transform_stream(chat_log, stream, messages),
|
||||
)
|
||||
if not isinstance(content, conversation.AssistantContent)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
if not chat_log.unresponded_tool_results:
|
||||
break
|
||||
|
@@ -18,10 +18,26 @@
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': 'Certainly, calling it now!',
|
||||
'native': None,
|
||||
'content': None,
|
||||
'native': ThinkingBlock(signature='ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/NoB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==', thinking='', type='thinking'),
|
||||
'role': 'assistant',
|
||||
'thinking_content': 'The user asked me to call a test function.Is it a test? What would the function do? Would it violate any privacy or security policies?',
|
||||
'tool_calls': None,
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': None,
|
||||
'native': RedactedThinkingBlock(data='EroBCkYIARgCKkBJDytPJhw//4vy3t7aE+LfIkxvkAh51cBPrAvBCo6AjgI57Zt9KWPnUVV50OQJ0KZzUFoGZG5sxg95zx4qMwkoEgz43Su3myJKckvj03waDBZLIBSeoAeRUeVsJCIwQ5edQN0sa+HNeB/KUBkoMUwV+IT0eIhcpFxnILdvxUAKM4R1o4KG3x+yO0eo/kyOKiKfrCPFQhvBVmTZPFhgA2Ow8L9gGDVipcz6x3Uu9YETGEny', type='redacted_thinking'),
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': 'Certainly, calling it now!',
|
||||
'native': ThinkingBlock(signature='ErUBCkYIARgCIkCYXaVNJShe3A86Hp7XUzh9YsCYBbJTbQsrklTAPtJ2sP/NoB6tSzpK/nTL6CjSo2R6n0KNBIg5MH6asM2R/kmaEgyB/X1FtZq5OQAC7jUaDEPWCdcwGQ4RaBy5wiIwmRxExIlDhoY6tILoVPnOExkC/0igZxHEwxK8RU/fmw0b+o+TwAarzUitwzbo21E5Kh3pa3I6yqVROf1t2F8rFocNUeCegsWV/ytwYV+ayA==', thinking='', type='thinking'),
|
||||
'role': 'assistant',
|
||||
'thinking_content': "Okay, let's give it a shot. Will I pass the test?",
|
||||
'tool_calls': list([
|
||||
dict({
|
||||
'id': 'toolu_0123456789AbCdEfGhIjKlM',
|
||||
@@ -321,6 +337,39 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_redacted_thinking
|
||||
list([
|
||||
dict({
|
||||
'attachments': None,
|
||||
'content': 'ANTHROPIC_MAGIC_STRING_TRIGGER_REDACTED_THINKING_46C9A13E193C177646C7398A98432ECCCE4C1253D5E2D82641AC0E52CC2876CB',
|
||||
'role': 'user',
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': None,
|
||||
'native': RedactedThinkingBlock(data='EroBCkYIARgCKkBJDytPJhw//4vy3t7aE+LfIkxvkAh51cBPrAvBCo6AjgI57Zt9KWPnUVV50OQJ0KZzUFoGZG5sxg95zx4qMwkoEgz43Su3myJKckvj03waDBZLIBSeoAeRUeVsJCIwQ5edQN0sa+HNeB/KUBkoMUwV+IT0eIhcpFxnILdvxUAKM4R1o4KG3x+yO0eo/kyOKiKfrCPFQhvBVmTZPFhgA2Ow8L9gGDVipcz6x3Uu9YETGEny', type='redacted_thinking'),
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': None,
|
||||
'native': RedactedThinkingBlock(data='EroBCkYIARgCKkBJDytPJhw//4vy3t7aE+LfIkxvkAh51cBPrAvBCo6AjgI57Zt9KWPnUVV50OQJ0KZzUFoGZG5sxg95zx4qMwkoEgz43Su3myJKckvj03waDBZLIBSeoAeRUeVsJCIwQ5edQN0sa+HNeB/KUBkoMUwV+IT0eIhcpFxnILdvxUAKM4R1o4KG3x+yO0eo/kyOKiKfrCPFQhvBVmTZPFhgA2Ow8L9gGDVipcz6x3Uu9YETGEny', type='redacted_thinking'),
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
dict({
|
||||
'agent_id': 'conversation.claude_conversation',
|
||||
'content': 'How can I help you today?',
|
||||
'native': RedactedThinkingBlock(data='EroBCkYIARgCKkBJDytPJhw//4vy3t7aE+LfIkxvkAh51cBPrAvBCo6AjgI57Zt9KWPnUVV50OQJ0KZzUFoGZG5sxg95zx4qMwkoEgz43Su3myJKckvj03waDBZLIBSeoAeRUeVsJCIwQ5edQN0sa+HNeB/KUBkoMUwV+IT0eIhcpFxnILdvxUAKM4R1o4KG3x+yO0eo/kyOKiKfrCPFQhvBVmTZPFhgA2Ow8L9gGDVipcz6x3Uu9YETGEny', type='redacted_thinking'),
|
||||
'role': 'assistant',
|
||||
'thinking_content': None,
|
||||
'tool_calls': None,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_unknown_hass_api
|
||||
dict({
|
||||
'continue_conversation': False,
|
||||
|
@@ -728,6 +728,7 @@ async def test_redacted_thinking(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_with_extended_thinking: MockConfigEntry,
|
||||
mock_init_component,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test extended thinking with redacted thinking blocks."""
|
||||
with patch(
|
||||
@@ -756,8 +757,8 @@ async def test_redacted_thinking(
|
||||
chat_log = hass.data.get(conversation.chat_log.DATA_CHAT_LOGS).get(
|
||||
result.conversation_id
|
||||
)
|
||||
assert len(chat_log.content) == 3
|
||||
assert chat_log.content[2].content == "How can I help you today?"
|
||||
# Don't test the prompt because it's not deterministic
|
||||
assert chat_log.content[1:] == snapshot
|
||||
|
||||
|
||||
@patch("homeassistant.components.anthropic.entity.llm.AssistAPI._async_get_tools")
|
||||
|
Reference in New Issue
Block a user