Compare commits

...

14 Commits

Author SHA1 Message Date
Paulus Schoutsen 77fe61db3e Load the llm integration via the conversation dependency
The Assist API's built-in tools (date/time, calendar, todo, script, live
context, intents) now come from the llm integration's tool platform, so the
llm integration must be loaded for an Assist LLM API to expose them. Add llm
to conversation's dependencies so it loads wherever the conversation stack
(and thus any LLM agent) is active.

This makes the registry tools reach Assist consumers, which surfaces in
anthropic's test_extended_thinking_tool_call: it patched
AssistAPI._async_get_tools to control the tool set, but tools now also come
from the registry. Empty the registry in that test so its tool set stays the
single mock tool.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-16 01:05:14 -04:00
Paulus Schoutsen ac3a1493aa Register intent LLM tools from the llm integration
Relocate the intent-tool wrapping (and its two prompt fragments,
DEVICE_CONTROL_TOOL_USAGE_PROMPT and the no-timer note) out of AssistAPI
into a core provider registered by the llm integration. This is a
behavior-preserving relocation: every intent that was wrapped before is
still wrapped, with byte-identical tool names and prompt lines.

The new homeassistant/components/llm/intents.py owns IGNORE_INTENTS, the
timer-intent set, the slugify cache and the device-control prompt, and
registers an intent_tools provider in the llm integration's async_setup
(before GetDateTime/live-context so the tool order is unchanged). The
provider reproduces AssistAPI's exposure/timer filtering exactly and only
emits its prompt when entities are actually exposed.

AssistAPI._async_get_tools now returns [] and the api_prompt no longer
hardcodes the device-control or no-timer fragments; the now-unused imports
and helpers are removed. The order-independent parity snapshot
(test_assist_api_snapshot) stays green without regeneration.

llm gains an after_dependencies on intent (its provider uses intent
component helpers, which are safe without intent being set up). mcp_server,
which consumes the assist API directly, now depends on llm so its
registered tools are available.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-16 00:56:01 -04:00
Paulus Schoutsen fbda9f4fc3 Set up llm integration in dynamic time injection test
test_dynamic_time_injection case 3 relies on the Assist API providing the
GetDateTime tool (which suppresses the dynamic time prompt). Since that tool
now comes from the llm integration's registry rather than being hardcoded in
AssistAPI, the test must set up the llm integration; otherwise GetDateTime is
absent and the time prompt is injected, failing the assertion.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-16 00:32:39 -04:00
Paulus Schoutsen 5802cada22 Register live-context LLM tool from the llm integration
Move GetLiveContextTool and its DYNAMIC_CONTEXT_PROMPT prompt fragment out of
AssistAPI into the llm integration. The tool is now contributed by a registered
provider (with its prompt via LLMTools(prompt=...)) that returns it only when
entities are exposed, reproducing the previous exposure-dependent behaviour.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-16 00:01:03 -04:00
Paulus Schoutsen f357dd3d49 Await tool platform discovery in calendar/todo LLM tests
test_calendar_get_events_tool and test_todo_get_items_tool set up their
integration last and then called async_get_api before the llm integration's
EVENT_COMPONENT_LOADED-driven platform discovery had registered the provider,
so they passed in the full-file run (later setups cycled the loop) but failed
in isolation / under randomized ordering. Add async_block_till_done() before
async_get_api so discovery completes deterministically.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 23:53:03 -04:00
Paulus Schoutsen 1eca57a01b Register script LLM tools from the script platform
Move ScriptTool (and its ActionTool base, the _get_cached_action_parameters
helper and the action-parameters cache it owns) out of AssistAPI into a
script/llm.py tool platform discovered by the llm integration, mirroring
calendar/llm.py and todo/llm.py. ActionTool was used exclusively by
ScriptTool, so the whole unit moves together; helpers/llm.py keeps only the
SCRIPT_DOMAIN import it still needs for _get_exposed_entities.

The provider reuses llm.async_get_exposed_entities to reproduce the exact
exposed-script entities AssistAPI fed the tool, keeping the emitted
ScriptTools byte-identical. The parity snapshot stays unchanged: script was
already set up in test_assist_api_snapshot, so the platform is discovered
without surfacing any new tools.

The ScriptTool tests in test_llm.py now set up the llm component (so the
platform is discovered) and reference the tool and cache via the script
component root (script.llm) to satisfy the component-root import rule.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 23:44:14 -04:00
Paulus Schoutsen 0dde553653 Register todo LLM tool from the todo platform
Move TodoGetItemsTool out of AssistAPI into a todo/llm.py tool platform
discovered by the llm integration, mirroring calendar/llm.py. The provider
reuses llm.async_get_exposed_entities to reproduce the exact exposed
to-do list names AssistAPI fed the tool, keeping todo_get_items
byte-identical.

The parity snapshot now sets up the todo component (required to load the
platform), which also registers todo's intent handlers. This additively
surfaces the HassListAddItem/HassListCompleteItem/HassListRemoveItem
intent tools in the snapshot; the moved todo_get_items tool and the
prompt are unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 23:30:17 -04:00
Paulus Schoutsen 5943b70577 Register calendar LLM tool from the calendar platform
Move CalendarGetEventsTool out of AssistAPI into a calendar/llm.py tool
platform discovered by the llm integration. Add a public
async_get_exposed_entities helper so the provider reproduces the exact
exposed-calendar names AssistAPI fed the tool, keeping the parity snapshot
byte-identical.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 23:19:51 -04:00
Paulus Schoutsen 99ee92029d Register GetDateTime LLM tool from the llm integration
Move the GetDateTimeTool class out of AssistAPI's hardcoded built-in list
and register it from the llm integration's async_setup via the tool
registry. Net behavior is unchanged: GetDateTime is still exposed by the
Assist API, it just comes from the registry now.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 23:07:45 -04:00
Paulus Schoutsen 717acdf771 Make Assist API parity snapshot order-independent
The v1 tool-platform refactor moves built-in tools and prompt fragments out
of AssistAPI into per-integration registrations, which changes their order
but not their content. Compare the tool set as a name-keyed mapping and the
prompt as a set of lines, so each migration step verifies same-tools /
same-content / same-prompt regardless of ordering. Final ordering parity with
dev is checked separately at the end of the refactor.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 23:01:59 -04:00
Paulus Schoutsen e76d83241e Source AssistAPI tools and prompt from the tool registry
AssistAPI.async_get_api_instance now also pulls registered tool
providers for the assist API, appending their tools and prompt
fragments to the hardcoded built-ins. Purely additive: with no
registrations, behaviour is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 22:43:41 -04:00
Paulus Schoutsen 6024f96d56 Add llm integration with tool platform discovery
New system integration that owns the LLM tools platform: its async_setup
drives async_process_integration_platforms(hass, "llm", ...) so integrations
can ship an <integration>/llm.py with an async_setup_tools hook to register
tools, mirroring the intent helper/integration split. The framework (registry,
Tool, APIs) stays in homeassistant.helpers.llm. No tools are registered yet,
so behavior is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 22:39:05 -04:00
Paulus Schoutsen 5eec6167cc Add LLM tool registration API and registry
Introduce a per-tool registration layer in helpers/llm.py: an LLMTools
result (tools + optional prompt fragment), a provider callback type, and
async_register_tool_provider / async_register_tool registering into one or
more API ids. Providers are stored in a registry and merged by
_async_get_registered_tools. Nothing consumes the registry yet, so there is
no behavior change (the Assist parity snapshot is unchanged); AssistAPI is
wired to it in a later step.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 22:22:31 -04:00
Paulus Schoutsen 184025cfb4 Add Assist API behavior-parity snapshot test
Capture the assembled Assist prompt and the full serialized tool set
(name, description, parameters) for a scenario exercising every built-in
tool type — intents, timer tools, calendar/todo/script tools, GetDateTime
and GetLiveContext. This is the parity net for the upcoming LLM tool
platform refactor: the snapshot must stay identical as tools and intents
move out of AssistAPI into per-integration platforms.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-15 22:19:09 -04:00
19 changed files with 1816 additions and 654 deletions
Generated
+2
View File
@@ -1022,6 +1022,8 @@ CLAUDE.md @home-assistant/core
/tests/components/litterrobot/ @natekspencer @tkdrob
/homeassistant/components/livisi/ @StefanIacobLivisi @planbnet
/tests/components/livisi/ @StefanIacobLivisi @planbnet
/homeassistant/components/llm/ @home-assistant/core
/tests/components/llm/ @home-assistant/core
/homeassistant/components/local_calendar/ @allenporter
/tests/components/local_calendar/ @allenporter
/homeassistant/components/local_ip/ @issacg
+108
View File
@@ -0,0 +1,108 @@
"""LLM tools for the calendar integration."""
from datetime import timedelta
from typing import cast
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import intent, llm
from homeassistant.util import dt as dt_util
from homeassistant.util.json import JsonObjectType
from . import SERVICE_GET_EVENTS
from .const import DOMAIN
async def async_setup_tools(hass: HomeAssistant) -> None:
"""Set up the calendar LLM tools."""
llm.async_register_tool_provider(
hass, _calendar_tools, apis={llm.LLM_API_ASSIST: None}
)
@callback
def _calendar_tools(hass: HomeAssistant, llm_context: llm.LLMContext) -> llm.LLMTools:
"""Return the calendar tools for the exposed calendars."""
if llm_context.assistant is None:
return llm.LLMTools(tools=[])
exposed = llm.async_get_exposed_entities(
hass, llm_context.assistant, include_state=False
)
if not exposed[DOMAIN]:
return llm.LLMTools(tools=[])
names = []
for info in exposed[DOMAIN].values():
names.extend(info["names"].split(", "))
return llm.LLMTools(tools=[CalendarGetEventsTool(names)])
class CalendarGetEventsTool(llm.Tool):
"""LLM Tool allowing querying a calendar."""
name = "calendar_get_events"
description = (
"Get events from a calendar. "
"When asked if something happens, search the whole week. "
"Results are RFC 5545 which means 'end' is exclusive."
)
def __init__(self, calendars: list[str]) -> None:
"""Init the get events tool."""
self.parameters = vol.Schema(
{
vol.Required("calendar"): vol.In(calendars),
vol.Required("range"): vol.In(["today", "week"]),
}
)
async def async_call(
self,
hass: HomeAssistant,
tool_input: llm.ToolInput,
llm_context: llm.LLMContext,
) -> JsonObjectType:
"""Query a calendar."""
data = self.parameters(tool_input.tool_args)
result = intent.async_match_targets(
hass,
intent.MatchTargetsConstraints(
name=data["calendar"],
domains=[DOMAIN],
assistant=llm_context.assistant,
),
)
if not result.is_match:
return {"success": False, "error": "Calendar not found"}
entity_id = result.states[0].entity_id
if data["range"] == "today":
start = dt_util.now()
end = dt_util.start_of_local_day() + timedelta(days=1)
elif data["range"] == "week":
start = dt_util.now()
end = dt_util.start_of_local_day() + timedelta(days=7)
service_data = {
"entity_id": entity_id,
"start_date_time": start.isoformat(),
"end_date_time": end.isoformat(),
}
service_result = await hass.services.async_call(
DOMAIN,
SERVICE_GET_EVENTS,
service_data,
context=llm_context.context,
blocking=True,
return_response=True,
)
events = [
event if "T" in event["start"] else {**event, "all_day": True}
for event in cast(dict, service_result)[entity_id]["events"]
]
return {"success": True, "result": events}
@@ -2,7 +2,7 @@
"domain": "conversation",
"name": "Conversation",
"codeowners": ["@home-assistant/core", "@synesthesiam", "@arturpragacz"],
"dependencies": ["http", "intent"],
"dependencies": ["http", "intent", "llm"],
"documentation": "https://www.home-assistant.io/integrations/conversation",
"integration_type": "entity",
"quality_scale": "internal",
+70
View File
@@ -0,0 +1,70 @@
"""The LLM integration.
Owns the LLM tools platform: integrations contribute tools to the LLM APIs
through an ``<integration>/llm.py`` platform with an ``async_setup_tools`` hook,
discovered here. The framework (registry, ``Tool``, the APIs) lives in
``homeassistant.helpers.llm``; this integration owns the lifecycle, mirroring the
``intent`` helper/integration split.
"""
from typing import Protocol
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, llm
from homeassistant.helpers.integration_platform import (
async_process_integration_platforms,
)
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
from .intents import intent_tools
from .tools import DYNAMIC_CONTEXT_PROMPT, GetDateTimeTool, GetLiveContextTool
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
class LLMToolsPlatformProtocol(Protocol):
"""Define the format that LLM tools platforms can have."""
async def async_setup_tools(self, hass: HomeAssistant) -> None:
"""Set up the integration's LLM tools."""
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the LLM integration."""
llm.async_register_tool_provider(
hass, intent_tools, apis={llm.LLM_API_ASSIST: None}
)
llm.async_register_tool(hass, GetDateTimeTool(), apis={llm.LLM_API_ASSIST: None})
llm.async_register_tool_provider(
hass, _live_context_tools, apis={llm.LLM_API_ASSIST: None}
)
await async_process_integration_platforms(
hass, DOMAIN, _async_process_llm_tools_platform
)
return True
@callback
def _live_context_tools(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> llm.LLMTools:
"""Return the live-context tool and its prompt when entities are exposed."""
if llm_context.assistant is None:
return llm.LLMTools(tools=[])
exposed = llm.async_get_exposed_entities(
hass, llm_context.assistant, include_state=False
)
exposed_domains = {info["domain"] for info in exposed["entities"].values()}
if not exposed_domains:
return llm.LLMTools(tools=[])
return llm.LLMTools(tools=[GetLiveContextTool()], prompt=DYNAMIC_CONTEXT_PROMPT)
async def _async_process_llm_tools_platform(
hass: HomeAssistant, domain: str, platform: LLMToolsPlatformProtocol
) -> None:
"""Register the LLM tools of an integration."""
await platform.async_setup_tools(hass)
+3
View File
@@ -0,0 +1,3 @@
"""Constants for the LLM integration."""
DOMAIN = "llm"
+88
View File
@@ -0,0 +1,88 @@
"""LLM tools wrapping exposed intents."""
from functools import cache, partial
import slugify as unicode_slug
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.weather import INTENT_GET_WEATHER
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import intent, llm
DEVICE_CONTROL_TOOL_USAGE_PROMPT = (
"When controlling Home Assistant always call the intent tools. "
"Use HassTurnOn to lock and HassTurnOff to unlock a lock. "
"When controlling a device, prefer passing just name and domain. "
"When controlling an area, prefer passing just area name and domain."
)
IGNORE_INTENTS = {
intent.INTENT_GET_TEMPERATURE,
INTENT_GET_WEATHER,
INTENT_OPEN_COVER, # deprecated
INTENT_CLOSE_COVER, # deprecated
intent.INTENT_GET_STATE,
intent.INTENT_NEVERMIND,
intent.INTENT_TOGGLE,
intent.INTENT_GET_CURRENT_DATE,
intent.INTENT_GET_CURRENT_TIME,
intent.INTENT_RESPOND,
}
TIMER_INTENTS = {
intent.INTENT_START_TIMER,
intent.INTENT_CANCEL_TIMER,
intent.INTENT_INCREASE_TIMER,
intent.INTENT_DECREASE_TIMER,
intent.INTENT_PAUSE_TIMER,
intent.INTENT_UNPAUSE_TIMER,
intent.INTENT_TIMER_STATUS,
}
_slugify = cache(partial(unicode_slug.slugify, separator="_", lowercase=False))
@callback
def intent_tools(hass: HomeAssistant, llm_context: llm.LLMContext) -> llm.LLMTools:
"""Return the intent tools and their prompt for the exposed entities."""
exposed = (
llm.async_get_exposed_entities(hass, llm_context.assistant, include_state=False)
if llm_context.assistant
else None
)
ignore = IGNORE_INTENTS
if not llm_context.device_id or not async_device_supports_timers(
hass, llm_context.device_id
):
ignore = ignore | TIMER_INTENTS
handlers = [
handler
for handler in intent.async_get(hass)
if handler.intent_type not in ignore
]
if exposed is not None:
exposed_domains = {info["domain"] for info in exposed["entities"].values()}
handlers = [
handler
for handler in handlers
if handler.platforms is None or handler.platforms & exposed_domains
]
tools: list[llm.Tool] = [
llm.IntentTool(_slugify(handler.intent_type), handler) for handler in handlers
]
prompt = None
if exposed and exposed["entities"]:
parts = [DEVICE_CONTROL_TOOL_USAGE_PROMPT]
if not llm_context.device_id or not async_device_supports_timers(
hass, llm_context.device_id
):
parts.append("This device is not able to start timers.")
prompt = "\n".join(parts)
return llm.LLMTools(tools=tools, prompt=prompt)
@@ -0,0 +1,10 @@
{
"domain": "llm",
"name": "LLM",
"after_dependencies": ["intent"],
"codeowners": ["@home-assistant/core"],
"documentation": "https://www.home-assistant.io/integrations/llm",
"integration_type": "system",
"iot_class": "calculated",
"quality_scale": "internal"
}
+220
View File
@@ -0,0 +1,220 @@
"""Tools for the LLM integration."""
import voluptuous as vol
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv, intent
from homeassistant.helpers.llm import (
NO_ENTITIES_PROMPT,
LLMContext,
Tool,
ToolInput,
async_get_exposed_entities,
)
from homeassistant.util import dt as dt_util, yaml as yaml_util
from homeassistant.util.json import JsonObjectType
DYNAMIC_CONTEXT_PROMPT = (
"You ARE equipped to answer questions about the"
" current state of\n"
"the home using the `GetLiveContext` tool."
" This is a primary function."
" Do not state you lack the\n"
"functionality if the question requires live data.\n"
"If the user asks about device existence/type"
' (e.g., "Do I have lights in the bedroom?"):'
" Answer\n"
"from the static context below.\n"
"If the user asks about the CURRENT state, value,"
' or mode (e.g., "Is the lock locked?",\n'
'"Is the fan on?",'
' "What mode is the thermostat in?",'
' "What is the temperature outside?"):\n'
" 1. Recognize this requires live data.\n"
" 2. You MUST call `GetLiveContext`."
" This tool will provide the needed real-time"
" information (like temperature from the local"
" weather, lock status, etc.).\n"
" 3. Use the tool's response** to answer the"
" user accurately"
' (e.g., "The temperature outside is'
' [value from tool].").\n'
"For general knowledge questions not about the"
" home: Answer truthfully from internal"
" knowledge.\n"
)
class GetDateTimeTool(Tool):
"""Tool for getting the current date and time."""
name = "GetDateTime"
description = "Provides the current date and time."
async def async_call(
self,
hass: HomeAssistant,
tool_input: ToolInput,
llm_context: LLMContext,
) -> JsonObjectType:
"""Get the current date and time."""
now = dt_util.now()
return {
"success": True,
"result": {
"date": now.strftime("%Y-%m-%d"),
"time": now.strftime("%H:%M:%S"),
"timezone": now.strftime("%Z"),
"weekday": now.strftime("%A"),
},
}
def _live_context_match_error(
match_result: intent.MatchTargetsResult,
name_filter: str | None,
area_filter: str | None,
domain_filter: list[str] | None,
) -> str:
"""Build an actionable error message for a failed GetLiveContext match."""
reason = match_result.no_match_reason
if reason is intent.MatchFailedReason.INVALID_AREA:
return f"Area '{match_result.no_match_name}' does not exist"
if reason is intent.MatchFailedReason.NAME:
return f"No exposed entities matched name '{name_filter}'"
if reason is intent.MatchFailedReason.AREA:
return f"No exposed entities found in area '{area_filter}'"
if reason is intent.MatchFailedReason.DOMAIN:
domains = ", ".join(domain_filter) if domain_filter else ""
return f"No exposed entities found in domain(s): {domains}"
return "No entities matched the provided filter"
class GetLiveContextTool(Tool):
"""Tool for getting the current state of exposed entities.
This returns state for all entities that have been exposed to
the assistant. This is different than the GetState intent, which
returns state for entities based on intent parameters.
"""
name = "GetLiveContext"
description = (
"Provides real-time information about the"
" CURRENT state, value, or mode of devices,"
" sensors, entities, or areas. "
"Use this tool for: "
"1. Answering questions about current"
" conditions (e.g., 'Is the light on?'). "
"2. As the first step in conditional actions"
" (e.g., 'If the weather is rainy, turn off"
" sprinklers' requires checking the weather"
" first). "
"You may filter for devices by name, domain,"
" and area, including combining those"
" filters. "
"Prefer filtering by domain when searching"
" for multiple devices of the same type."
)
parameters = vol.Schema(
{
vol.Optional(
"name",
description="Filter entities by name or alias (case-insensitive).",
): cv.string,
vol.Optional(
"domain",
description=(
"Filter entities by domain"
" (e.g. 'light', 'sensor')."
" Accepts a single domain or a list."
),
): vol.Any(cv.string, [cv.string]),
vol.Optional(
"area",
description="Filter entities by area name or alias (case-insensitive).",
): cv.string,
}
)
async def async_call(
self,
hass: HomeAssistant,
tool_input: ToolInput,
llm_context: LLMContext,
) -> JsonObjectType:
"""Get the current state of exposed entities."""
if llm_context.assistant is None:
# Note this doesn't happen in practice since this tool won't be
# exposed if no assistant is configured.
return {"success": False, "error": "No assistant configured"}
args = self.parameters(tool_input.tool_args)
exposed_entities = async_get_exposed_entities(
hass, llm_context.assistant, include_state=True
)
if not exposed_entities["entities"]:
return {"success": False, "error": NO_ENTITIES_PROMPT}
name_filter = args.get("name")
area_filter = args.get("area")
domain_filter = args.get("domain")
if isinstance(domain_filter, str):
domain_filter = [domain_filter]
if domain_filter is not None:
domain_filter = [
normalized_domain
for domain in domain_filter
if (normalized_domain := domain.strip().lower())
]
if name_filter or area_filter or domain_filter:
exposed_states = [
state
for entity_id in exposed_entities["entities"]
if (state := hass.states.get(entity_id)) is not None
]
match_result = intent.async_match_targets(
hass,
intent.MatchTargetsConstraints(
name=name_filter,
area_name=area_filter,
domains=domain_filter,
# This tool only returns context, so multiple entities
# sharing a name (e.g. "AC" in two areas) should all be
# returned rather than failing as an ambiguous match.
allow_duplicate_names=True,
),
states=exposed_states,
)
if not match_result.is_match:
return {
"success": False,
"error": _live_context_match_error(
match_result, name_filter, area_filter, domain_filter
),
}
matched_ids = {state.entity_id for state in match_result.states}
entities = [
info
for entity_id, info in exposed_entities["entities"].items()
if entity_id in matched_ids
]
else:
entities = list(exposed_entities["entities"].values())
prompt = [
"Live Context: An overview of the areas"
" and the devices in this smart home:",
yaml_util.dump(entities),
]
return {
"success": True,
"result": "\n".join(prompt),
}
@@ -3,7 +3,7 @@
"name": "Model Context Protocol Server",
"codeowners": ["@allenporter"],
"config_flow": true,
"dependencies": ["http", "conversation"],
"dependencies": ["http", "conversation", "llm"],
"documentation": "https://www.home-assistant.io/integrations/mcp_server",
"integration_type": "service",
"iot_class": "local_push",
+226
View File
@@ -0,0 +1,226 @@
"""LLM tools for the script integration."""
from typing import Any
import voluptuous as vol
from homeassistant.const import (
ATTR_DOMAIN,
ATTR_SERVICE,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED,
)
from homeassistant.core import Event, HomeAssistant, callback, split_entity_id
from homeassistant.helpers import (
area_registry as ar,
config_validation as cv,
entity_registry as er,
floor_registry as fr,
intent,
llm,
selector,
service,
)
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
from .const import DOMAIN
ACTION_PARAMETERS_CACHE: HassKey[
dict[str, dict[str, tuple[str | None, vol.Schema]]]
] = HassKey("llm_action_parameters_cache")
async def async_setup_tools(hass: HomeAssistant) -> None:
"""Set up the script LLM tools."""
llm.async_register_tool_provider(
hass, _script_tools, apis={llm.LLM_API_ASSIST: None}
)
@callback
def _script_tools(hass: HomeAssistant, llm_context: llm.LLMContext) -> llm.LLMTools:
"""Return the script tools for the exposed scripts."""
if llm_context.assistant is None:
return llm.LLMTools(tools=[])
exposed = llm.async_get_exposed_entities(
hass, llm_context.assistant, include_state=False
)
return llm.LLMTools(
tools=[
ScriptTool(hass, script_entity_id) for script_entity_id in exposed[DOMAIN]
]
)
def _get_cached_action_parameters(
hass: HomeAssistant, domain: str, action: str
) -> tuple[str | None, vol.Schema]:
"""Get action description and schema."""
description = None
parameters = vol.Schema({})
parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE)
if parameters_cache is None:
parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {}
@callback
def clear_cache(event: Event) -> None:
"""Clear action parameter cache on action removal."""
if (
event.data[ATTR_DOMAIN] in parameters_cache
and event.data[ATTR_SERVICE]
in parameters_cache[event.data[ATTR_DOMAIN]]
):
parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE])
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
@callback
def on_homeassistant_close(event: Event) -> None:
"""Cleanup."""
cancel()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close)
if domain in parameters_cache and action in parameters_cache[domain]:
return parameters_cache[domain][action]
if action_desc := service.async_get_cached_service_description(
hass, domain, action
):
description = action_desc.get("description")
schema: dict[vol.Marker, Any] = {}
fields = action_desc.get("fields", {})
for field, config in fields.items():
field_description = config.get("description")
if not field_description:
field_description = config.get("name")
key: vol.Marker
if config.get("required"):
key = vol.Required(field, description=field_description)
else:
key = vol.Optional(field, description=field_description)
if "selector" in config:
schema[key] = selector.selector(config["selector"])
else:
schema[key] = cv.string
parameters = vol.Schema(schema)
if domain == DOMAIN:
entity_registry = er.async_get(hass)
if (
entity_id := entity_registry.async_get_entity_id(domain, domain, action)
) is not None and (
entity_entry := entity_registry.async_get(entity_id)
) is not None:
aliases = er.async_get_entity_aliases(hass, entity_entry)
if aliases:
if description:
description = description + ". Aliases: " + str(sorted(aliases))
else:
description = "Aliases: " + str(sorted(aliases))
parameters_cache.setdefault(domain, {})[action] = (description, parameters)
return description, parameters
class ActionTool(llm.Tool):
"""LLM Tool representing an action."""
def __init__(
self,
hass: HomeAssistant,
domain: str,
action: str,
) -> None:
"""Init the class."""
self._domain = domain
self._action = action
self.name = f"{domain}__{action}"
# Note: _get_cached_action_parameters only works for services which
# add their description directly to the service description cache.
# This is not the case for most services, but it is for scripts.
# If we want to use `ActionTool` for services other than scripts, we
# need to add a coroutine function to fetch the non-cached description
# and schema.
self.description, self.parameters = _get_cached_action_parameters(
hass, domain, action
)
async def async_call(
self,
hass: HomeAssistant,
tool_input: llm.ToolInput,
llm_context: llm.LLMContext,
) -> JsonObjectType:
"""Call the action."""
for field, validator in self.parameters.schema.items():
if field not in tool_input.tool_args:
continue
if isinstance(validator, selector.AreaSelector):
area_reg = ar.async_get(hass)
if validator.config.get("multiple"):
areas: list[ar.AreaEntry] = []
for area in tool_input.tool_args[field]:
areas.extend(intent.find_areas(area, area_reg))
tool_input.tool_args[field] = list({area.id for area in areas})
else:
area = tool_input.tool_args[field]
area = list(intent.find_areas(area, area_reg))[0].id
tool_input.tool_args[field] = area
elif isinstance(validator, selector.FloorSelector):
floor_reg = fr.async_get(hass)
if validator.config.get("multiple"):
floors: list[fr.FloorEntry] = []
for floor in tool_input.tool_args[field]:
floors.extend(intent.find_floors(floor, floor_reg))
tool_input.tool_args[field] = list(
{floor.floor_id for floor in floors}
)
else:
floor = tool_input.tool_args[field]
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
tool_input.tool_args[field] = floor
result = await hass.services.async_call(
self._domain,
self._action,
tool_input.tool_args,
context=llm_context.context,
blocking=True,
return_response=True,
)
return {"success": True, "result": result}
class ScriptTool(ActionTool):
"""LLM Tool representing a Script."""
def __init__(
self,
hass: HomeAssistant,
script_entity_id: str,
) -> None:
"""Init the class."""
script_name = split_entity_id(script_entity_id)[1]
action = script_name
entity_registry = er.async_get(hass)
entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id:
action = entity_entry.unique_id
super().__init__(hass, DOMAIN, action)
self.name = script_name
if self.name[0].isdigit():
self.name = "_" + self.name
+104
View File
@@ -0,0 +1,104 @@
"""LLM tools for the todo integration."""
from typing import Any, cast
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import intent, llm
from homeassistant.util.json import JsonObjectType
from .const import DOMAIN, TodoServices
async def async_setup_tools(hass: HomeAssistant) -> None:
"""Set up the todo LLM tools."""
llm.async_register_tool_provider(hass, _todo_tools, apis={llm.LLM_API_ASSIST: None})
@callback
def _todo_tools(hass: HomeAssistant, llm_context: llm.LLMContext) -> llm.LLMTools:
"""Return the todo tools for the exposed to-do lists."""
if llm_context.assistant is None:
return llm.LLMTools(tools=[])
exposed = llm.async_get_exposed_entities(
hass, llm_context.assistant, include_state=False
)
names = []
for info in exposed["entities"].values():
if info["domain"] != DOMAIN:
continue
names.extend(info["names"].split(", "))
if not names:
return llm.LLMTools(tools=[])
return llm.LLMTools(tools=[TodoGetItemsTool(names)])
class TodoGetItemsTool(llm.Tool):
"""LLM Tool allowing querying a to-do list."""
name = "todo_get_items"
description = (
"Query a to-do list to find out what items are on it. "
"Use this to answer questions like "
"'What's on my task list?' or "
"'Read my grocery list'. "
"Filters items by status (needs_action, completed, all)."
)
def __init__(self, todo_lists: list[str]) -> None:
"""Init the get items tool."""
self.parameters = vol.Schema(
{
vol.Required("todo_list"): vol.In(todo_lists),
vol.Optional(
"status",
description=(
"Filter returned items by status,"
" by default returns incomplete"
" items"
),
default="needs_action",
): vol.In(["needs_action", "completed", "all"]),
}
)
async def async_call(
self,
hass: HomeAssistant,
tool_input: llm.ToolInput,
llm_context: llm.LLMContext,
) -> JsonObjectType:
"""Query a to-do list."""
data = self.parameters(tool_input.tool_args)
result = intent.async_match_targets(
hass,
intent.MatchTargetsConstraints(
name=data["todo_list"],
domains=[DOMAIN],
assistant=llm_context.assistant,
),
)
if not result.is_match:
return {"success": False, "error": "To-do list not found"}
entity_id = result.states[0].entity_id
service_data: dict[str, Any] = {"entity_id": entity_id}
if status := data.get("status"):
if status == "all":
service_data["status"] = ["needs_action", "completed"]
else:
service_data["status"] = [status]
service_result = await hass.services.async_call(
DOMAIN,
TodoServices.GET_ITEMS,
service_data,
context=llm_context.context,
blocking=True,
return_response=True,
)
if not service_result:
return {"success": False, "error": "To-do list not found"}
items = cast(dict, service_result)[entity_id]["items"]
return {"success": True, "result": items}
+119 -627
View File
@@ -3,35 +3,20 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field as dc_field
from datetime import timedelta
from decimal import Decimal
from enum import Enum
from functools import cache, partial
from operator import attrgetter
from typing import Any, cast
from typing import Any
import slugify as unicode_slug
import voluptuous as vol
from voluptuous_openapi import UNSUPPORTED, convert
from homeassistant.components.calendar import (
DOMAIN as CALENDAR_DOMAIN,
SERVICE_GET_EVENTS,
)
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
from homeassistant.components.calendar import DOMAIN as CALENDAR_DOMAIN
from homeassistant.components.homeassistant import async_should_expose
from homeassistant.components.intent import async_device_supports_timers
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
from homeassistant.components.sensor import async_rounded_state
from homeassistant.components.todo import DOMAIN as TODO_DOMAIN, TodoServices
from homeassistant.components.weather import INTENT_GET_WEATHER
from homeassistant.const import (
ATTR_DOMAIN,
ATTR_SERVICE,
EVENT_HOMEASSISTANT_CLOSE,
EVENT_SERVICE_REMOVED,
)
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util, yaml as yaml_util
from homeassistant.util.hass_dict import HassKey
@@ -46,15 +31,9 @@ from . import (
floor_registry as fr,
intent,
selector,
service,
)
from .singleton import singleton
ACTION_PARAMETERS_CACHE: HassKey[
dict[str, dict[str, tuple[str | None, vol.Schema]]]
] = HassKey("llm_action_parameters_cache")
LLM_API_ASSIST = "assist"
DATE_TIME_PROMPT = (
@@ -72,43 +51,6 @@ NO_ENTITIES_PROMPT = (
"to their voice assistant in Home Assistant."
)
DEVICE_CONTROL_TOOL_USAGE_PROMPT = (
"When controlling Home Assistant always call the intent tools. "
"Use HassTurnOn to lock and HassTurnOff to unlock a lock. "
"When controlling a device, prefer passing just name and domain. "
"When controlling an area, prefer passing just area name and domain."
)
DYNAMIC_CONTEXT_PROMPT = (
"You ARE equipped to answer questions about the"
" current state of\n"
"the home using the `GetLiveContext` tool."
" This is a primary function."
" Do not state you lack the\n"
"functionality if the question requires live data.\n"
"If the user asks about device existence/type"
' (e.g., "Do I have lights in the bedroom?"):'
" Answer\n"
"from the static context below.\n"
"If the user asks about the CURRENT state, value,"
' or mode (e.g., "Is the lock locked?",\n'
'"Is the fan on?",'
' "What mode is the thermostat in?",'
' "What is the temperature outside?"):\n'
" 1. Recognize this requires live data.\n"
" 2. You MUST call `GetLiveContext`."
" This tool will provide the needed real-time"
" information (like temperature from the local"
" weather, lock status, etc.).\n"
" 3. Use the tool's response** to answer the"
" user accurately"
' (e.g., "The temperature outside is'
' [value from tool].").\n'
"For general knowledge questions not about the"
" home: Answer truthfully from internal"
" knowledge.\n"
)
@callback
def async_render_no_api_prompt(hass: HomeAssistant) -> str:
@@ -228,6 +170,95 @@ class Tool:
return f"<{self.__class__.__name__} - {self.name}>"
@dataclass(slots=True)
class LLMTools:
"""Tools and an optional prompt fragment contributed by a provider."""
tools: list[Tool]
prompt: str | None = None
type LLMToolProvider = Callable[[HomeAssistant, LLMContext], LLMTools]
@dataclass(slots=True)
class _ToolProviderRegistration:
"""A registered tool provider and its importance within an API surface."""
provider: LLMToolProvider
importance: int | None
_TOOL_PROVIDERS: HassKey[dict[str, list[_ToolProviderRegistration]]] = HassKey(
"llm_tool_providers"
)
@callback
def async_register_tool_provider(
hass: HomeAssistant,
provider: LLMToolProvider,
*,
apis: dict[str, int | None],
) -> Callable[[], None]:
"""Register a provider that contributes tools to one or more LLM APIs.
The provider is evaluated per request with the ``LLMContext`` and returns
the tools (and an optional prompt fragment) to expose. ``apis`` maps each
API id the provider contributes to onto an importance; the importance is
not used yet, pass ``None``.
"""
registrations = hass.data.setdefault(_TOOL_PROVIDERS, {})
registered: list[tuple[str, _ToolProviderRegistration]] = []
for api_id, importance in apis.items():
registration = _ToolProviderRegistration(provider, importance)
registrations.setdefault(api_id, []).append(registration)
registered.append((api_id, registration))
@callback
def unregister() -> None:
"""Unregister the tool provider."""
for api_id, registration in registered:
registrations[api_id].remove(registration)
return unregister
@callback
def async_register_tool(
hass: HomeAssistant,
tool: Tool,
*,
apis: dict[str, int | None],
) -> Callable[[], None]:
"""Register a single static tool with one or more LLM APIs."""
@callback
def _provider(_hass: HomeAssistant, _llm_context: LLMContext) -> LLMTools:
return LLMTools(tools=[tool])
return async_register_tool_provider(hass, _provider, apis=apis)
@callback
def _async_get_registered_tools(
hass: HomeAssistant, api_id: str, llm_context: LLMContext
) -> LLMTools:
"""Return the tools and merged prompt from all providers for an API."""
registrations = hass.data.get(_TOOL_PROVIDERS)
if not registrations or not (records := registrations.get(api_id)):
return LLMTools(tools=[])
tools: list[Tool] = []
prompts: list[str] = []
for registration in records:
result = registration.provider(hass, llm_context)
tools.extend(result.tools)
if result.prompt:
prompts.append(result.prompt)
return LLMTools(tools=tools, prompt="\n".join(prompts) if prompts else None)
@dataclass
class APIInstance:
"""Instance of an API to be used by an LLM."""
@@ -457,19 +488,6 @@ class MergedAPI(API):
class AssistAPI(API):
"""API exposing Assist API to LLMs."""
IGNORE_INTENTS = {
intent.INTENT_GET_TEMPERATURE,
INTENT_GET_WEATHER,
INTENT_OPEN_COVER, # deprecated
INTENT_CLOSE_COVER, # deprecated
intent.INTENT_GET_STATE,
intent.INTENT_NEVERMIND,
intent.INTENT_TOGGLE,
intent.INTENT_GET_CURRENT_DATE,
intent.INTENT_GET_CURRENT_TIME,
intent.INTENT_RESPOND,
}
def __init__(self, hass: HomeAssistant) -> None:
"""Init the class."""
super().__init__(
@@ -477,9 +495,6 @@ class AssistAPI(API):
id=LLM_API_ASSIST,
name="Assist",
)
self.cached_slugify = cache(
partial(unicode_slug.slugify, separator="_", lowercase=False)
)
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
"""Return the instance of the API."""
@@ -490,11 +505,20 @@ class AssistAPI(API):
else:
exposed_entities = None
registered = _async_get_registered_tools(self.hass, LLM_API_ASSIST, llm_context)
api_prompt = self._async_get_api_prompt(llm_context, exposed_entities)
if registered.prompt:
api_prompt = f"{api_prompt}\n{registered.prompt}"
tools = self._async_get_tools(llm_context, exposed_entities)
tools.extend(registered.tools)
return APIInstance(
api=self,
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
api_prompt=api_prompt,
llm_context=llm_context,
tools=self._async_get_tools(llm_context, exposed_entities),
tools=tools,
custom_serializer=selector_serializer,
)
@@ -507,24 +531,13 @@ class AssistAPI(API):
# Collect all parts, filtering out any None values
prompt_parts = [
DEVICE_CONTROL_TOOL_USAGE_PROMPT,
DYNAMIC_CONTEXT_PROMPT,
*self._async_get_exposed_entities_prompt(exposed_entities),
self._async_get_voice_satellite_area_prompt(llm_context),
self._async_get_no_timer_prompt(llm_context),
]
# Filter out None and empty strings before joining
return "\n".join([part for part in prompt_parts if part])
@callback
def _async_get_no_timer_prompt(self, llm_context: LLMContext) -> str | None:
if not llm_context.device_id or not async_device_supports_timers(
self.hass, llm_context.device_id
):
return "This device is not able to start timers."
return None
@callback
def _async_get_voice_satellite_area_prompt(self, llm_context: LLMContext) -> str:
"""Return the area prompt for the voice satellite."""
@@ -579,70 +592,7 @@ class AssistAPI(API):
self, llm_context: LLMContext, exposed_entities: dict | None
) -> list[Tool]:
"""Return a list of LLM tools."""
ignore_intents = self.IGNORE_INTENTS
if not llm_context.device_id or not async_device_supports_timers(
self.hass, llm_context.device_id
):
ignore_intents = ignore_intents | {
intent.INTENT_START_TIMER,
intent.INTENT_CANCEL_TIMER,
intent.INTENT_INCREASE_TIMER,
intent.INTENT_DECREASE_TIMER,
intent.INTENT_PAUSE_TIMER,
intent.INTENT_UNPAUSE_TIMER,
intent.INTENT_TIMER_STATUS,
}
intent_handlers = [
intent_handler
for intent_handler in intent.async_get(self.hass)
if intent_handler.intent_type not in ignore_intents
]
exposed_domains: set[str] | None = None
if exposed_entities is not None:
exposed_domains = {
info["domain"] for info in exposed_entities["entities"].values()
}
intent_handlers = [
intent_handler
for intent_handler in intent_handlers
if intent_handler.platforms is None
or intent_handler.platforms & exposed_domains
]
tools: list[Tool] = [
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
for intent_handler in intent_handlers
]
tools.append(GetDateTimeTool())
if exposed_entities:
if exposed_entities[CALENDAR_DOMAIN]:
names = []
for info in exposed_entities[CALENDAR_DOMAIN].values():
names.extend(info["names"].split(", "))
tools.append(CalendarGetEventsTool(names))
if exposed_domains is not None and TODO_DOMAIN in exposed_domains:
names = []
for info in exposed_entities["entities"].values():
if info["domain"] != TODO_DOMAIN:
continue
names.extend(info["names"].split(", "))
tools.append(TodoGetItemsTool(names))
tools.extend(
ScriptTool(self.hass, script_entity_id)
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
)
if exposed_domains:
tools.append(GetLiveContextTool())
return tools
return []
def _get_exposed_entities(
@@ -755,6 +705,21 @@ def _get_exposed_entities(
return data
@callback
def async_get_exposed_entities(
hass: HomeAssistant,
assistant: str,
*,
include_state: bool = False,
) -> dict[str, dict[str, dict[str, Any]]]:
"""Get exposed entities for a tool provider.
Splits out calendars and scripts. Tool providers can use this to reproduce
the exact entity names that AssistAPI feeds to its built-in tools.
"""
return _get_exposed_entities(hass, assistant, include_state=include_state)
def selector_serializer(schema: Any) -> Any: # noqa: C901
"""Convert selectors into OpenAPI schema."""
if not isinstance(schema, selector.Selector):
@@ -893,476 +858,3 @@ def selector_serializer(schema: Any) -> Any: # noqa: C901
return {"type": "array", "items": {"type": "string"}}
return {"type": "string"}
def _get_cached_action_parameters(
hass: HomeAssistant, domain: str, action: str
) -> tuple[str | None, vol.Schema]:
"""Get action description and schema."""
description = None
parameters = vol.Schema({})
parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE)
if parameters_cache is None:
parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {}
@callback
def clear_cache(event: Event) -> None:
"""Clear action parameter cache on action removal."""
if (
event.data[ATTR_DOMAIN] in parameters_cache
and event.data[ATTR_SERVICE]
in parameters_cache[event.data[ATTR_DOMAIN]]
):
parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE])
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
@callback
def on_homeassistant_close(event: Event) -> None:
"""Cleanup."""
cancel()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close)
if domain in parameters_cache and action in parameters_cache[domain]:
return parameters_cache[domain][action]
if action_desc := service.async_get_cached_service_description(
hass, domain, action
):
description = action_desc.get("description")
schema: dict[vol.Marker, Any] = {}
fields = action_desc.get("fields", {})
for field, config in fields.items():
field_description = config.get("description")
if not field_description:
field_description = config.get("name")
key: vol.Marker
if config.get("required"):
key = vol.Required(field, description=field_description)
else:
key = vol.Optional(field, description=field_description)
if "selector" in config:
schema[key] = selector.selector(config["selector"])
else:
schema[key] = cv.string
parameters = vol.Schema(schema)
if domain == SCRIPT_DOMAIN:
entity_registry = er.async_get(hass)
if (
entity_id := entity_registry.async_get_entity_id(domain, domain, action)
) is not None and (
entity_entry := entity_registry.async_get(entity_id)
) is not None:
aliases = er.async_get_entity_aliases(hass, entity_entry)
if aliases:
if description:
description = description + ". Aliases: " + str(sorted(aliases))
else:
description = "Aliases: " + str(sorted(aliases))
parameters_cache.setdefault(domain, {})[action] = (description, parameters)
return description, parameters
class ActionTool(Tool):
"""LLM Tool representing an action."""
def __init__(
self,
hass: HomeAssistant,
domain: str,
action: str,
) -> None:
"""Init the class."""
self._domain = domain
self._action = action
self.name = f"{domain}__{action}"
# Note: _get_cached_action_parameters only works for services which
# add their description directly to the service description cache.
# This is not the case for most services, but it is for scripts.
# If we want to use `ActionTool` for services other than scripts, we
# need to add a coroutine function to fetch the non-cached description
# and schema.
self.description, self.parameters = _get_cached_action_parameters(
hass, domain, action
)
async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType:
"""Call the action."""
for field, validator in self.parameters.schema.items():
if field not in tool_input.tool_args:
continue
if isinstance(validator, selector.AreaSelector):
area_reg = ar.async_get(hass)
if validator.config.get("multiple"):
areas: list[ar.AreaEntry] = []
for area in tool_input.tool_args[field]:
areas.extend(intent.find_areas(area, area_reg))
tool_input.tool_args[field] = list({area.id for area in areas})
else:
area = tool_input.tool_args[field]
area = list(intent.find_areas(area, area_reg))[0].id
tool_input.tool_args[field] = area
elif isinstance(validator, selector.FloorSelector):
floor_reg = fr.async_get(hass)
if validator.config.get("multiple"):
floors: list[fr.FloorEntry] = []
for floor in tool_input.tool_args[field]:
floors.extend(intent.find_floors(floor, floor_reg))
tool_input.tool_args[field] = list(
{floor.floor_id for floor in floors}
)
else:
floor = tool_input.tool_args[field]
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
tool_input.tool_args[field] = floor
result = await hass.services.async_call(
self._domain,
self._action,
tool_input.tool_args,
context=llm_context.context,
blocking=True,
return_response=True,
)
return {"success": True, "result": result}
class ScriptTool(ActionTool):
"""LLM Tool representing a Script."""
def __init__(
self,
hass: HomeAssistant,
script_entity_id: str,
) -> None:
"""Init the class."""
script_name = split_entity_id(script_entity_id)[1]
action = script_name
entity_registry = er.async_get(hass)
entity_entry = entity_registry.async_get(script_entity_id)
if entity_entry and entity_entry.unique_id:
action = entity_entry.unique_id
super().__init__(hass, SCRIPT_DOMAIN, action)
self.name = script_name
if self.name[0].isdigit():
self.name = "_" + self.name
class CalendarGetEventsTool(Tool):
"""LLM Tool allowing querying a calendar."""
name = "calendar_get_events"
description = (
"Get events from a calendar. "
"When asked if something happens, search the whole week. "
"Results are RFC 5545 which means 'end' is exclusive."
)
def __init__(self, calendars: list[str]) -> None:
"""Init the get events tool."""
self.parameters = vol.Schema(
{
vol.Required("calendar"): vol.In(calendars),
vol.Required("range"): vol.In(["today", "week"]),
}
)
async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType:
"""Query a calendar."""
data = self.parameters(tool_input.tool_args)
result = intent.async_match_targets(
hass,
intent.MatchTargetsConstraints(
name=data["calendar"],
domains=[CALENDAR_DOMAIN],
assistant=llm_context.assistant,
),
)
if not result.is_match:
return {"success": False, "error": "Calendar not found"}
entity_id = result.states[0].entity_id
if data["range"] == "today":
start = dt_util.now()
end = dt_util.start_of_local_day() + timedelta(days=1)
elif data["range"] == "week":
start = dt_util.now()
end = dt_util.start_of_local_day() + timedelta(days=7)
service_data = {
"entity_id": entity_id,
"start_date_time": start.isoformat(),
"end_date_time": end.isoformat(),
}
service_result = await hass.services.async_call(
CALENDAR_DOMAIN,
SERVICE_GET_EVENTS,
service_data,
context=llm_context.context,
blocking=True,
return_response=True,
)
events = [
event if "T" in event["start"] else {**event, "all_day": True}
for event in cast(dict, service_result)[entity_id]["events"]
]
return {"success": True, "result": events}
class TodoGetItemsTool(Tool):
"""LLM Tool allowing querying a to-do list."""
name = "todo_get_items"
description = (
"Query a to-do list to find out what items are on it. "
"Use this to answer questions like "
"'What's on my task list?' or "
"'Read my grocery list'. "
"Filters items by status (needs_action, completed, all)."
)
def __init__(self, todo_lists: list[str]) -> None:
"""Init the get items tool."""
self.parameters = vol.Schema(
{
vol.Required("todo_list"): vol.In(todo_lists),
vol.Optional(
"status",
description=(
"Filter returned items by status,"
" by default returns incomplete"
" items"
),
default="needs_action",
): vol.In(["needs_action", "completed", "all"]),
}
)
async def async_call(
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
) -> JsonObjectType:
"""Query a to-do list."""
data = self.parameters(tool_input.tool_args)
result = intent.async_match_targets(
hass,
intent.MatchTargetsConstraints(
name=data["todo_list"],
domains=[TODO_DOMAIN],
assistant=llm_context.assistant,
),
)
if not result.is_match:
return {"success": False, "error": "To-do list not found"}
entity_id = result.states[0].entity_id
service_data: dict[str, Any] = {"entity_id": entity_id}
if status := data.get("status"):
if status == "all":
service_data["status"] = ["needs_action", "completed"]
else:
service_data["status"] = [status]
service_result = await hass.services.async_call(
TODO_DOMAIN,
TodoServices.GET_ITEMS,
service_data,
context=llm_context.context,
blocking=True,
return_response=True,
)
if not service_result:
return {"success": False, "error": "To-do list not found"}
items = cast(dict, service_result)[entity_id]["items"]
return {"success": True, "result": items}
def _live_context_match_error(
match_result: intent.MatchTargetsResult,
name_filter: str | None,
area_filter: str | None,
domain_filter: list[str] | None,
) -> str:
"""Build an actionable error message for a failed GetLiveContext match."""
reason = match_result.no_match_reason
if reason is intent.MatchFailedReason.INVALID_AREA:
return f"Area '{match_result.no_match_name}' does not exist"
if reason is intent.MatchFailedReason.NAME:
return f"No exposed entities matched name '{name_filter}'"
if reason is intent.MatchFailedReason.AREA:
return f"No exposed entities found in area '{area_filter}'"
if reason is intent.MatchFailedReason.DOMAIN:
domains = ", ".join(domain_filter) if domain_filter else ""
return f"No exposed entities found in domain(s): {domains}"
return "No entities matched the provided filter"
class GetLiveContextTool(Tool):
"""Tool for getting the current state of exposed entities.
This returns state for all entities that have been exposed to
the assistant. This is different than the GetState intent, which
returns state for entities based on intent parameters.
"""
name = "GetLiveContext"
description = (
"Provides real-time information about the"
" CURRENT state, value, or mode of devices,"
" sensors, entities, or areas. "
"Use this tool for: "
"1. Answering questions about current"
" conditions (e.g., 'Is the light on?'). "
"2. As the first step in conditional actions"
" (e.g., 'If the weather is rainy, turn off"
" sprinklers' requires checking the weather"
" first). "
"You may filter for devices by name, domain,"
" and area, including combining those"
" filters. "
"Prefer filtering by domain when searching"
" for multiple devices of the same type."
)
parameters = vol.Schema(
{
vol.Optional(
"name",
description="Filter entities by name or alias (case-insensitive).",
): cv.string,
vol.Optional(
"domain",
description=(
"Filter entities by domain"
" (e.g. 'light', 'sensor')."
" Accepts a single domain or a list."
),
): vol.Any(cv.string, [cv.string]),
vol.Optional(
"area",
description="Filter entities by area name or alias (case-insensitive).",
): cv.string,
}
)
async def async_call(
self,
hass: HomeAssistant,
tool_input: ToolInput,
llm_context: LLMContext,
) -> JsonObjectType:
"""Get the current state of exposed entities."""
if llm_context.assistant is None:
# Note this doesn't happen in practice since this tool won't be
# exposed if no assistant is configured.
return {"success": False, "error": "No assistant configured"}
args = self.parameters(tool_input.tool_args)
exposed_entities = _get_exposed_entities(hass, llm_context.assistant)
if not exposed_entities["entities"]:
return {"success": False, "error": NO_ENTITIES_PROMPT}
name_filter = args.get("name")
area_filter = args.get("area")
domain_filter = args.get("domain")
if isinstance(domain_filter, str):
domain_filter = [domain_filter]
if domain_filter is not None:
domain_filter = [
normalized_domain
for domain in domain_filter
if (normalized_domain := domain.strip().lower())
]
if name_filter or area_filter or domain_filter:
exposed_states = [
state
for entity_id in exposed_entities["entities"]
if (state := hass.states.get(entity_id)) is not None
]
match_result = intent.async_match_targets(
hass,
intent.MatchTargetsConstraints(
name=name_filter,
area_name=area_filter,
domains=domain_filter,
# This tool only returns context, so multiple entities
# sharing a name (e.g. "AC" in two areas) should all be
# returned rather than failing as an ambiguous match.
allow_duplicate_names=True,
),
states=exposed_states,
)
if not match_result.is_match:
return {
"success": False,
"error": _live_context_match_error(
match_result, name_filter, area_filter, domain_filter
),
}
matched_ids = {state.entity_id for state in match_result.states}
entities = [
info
for entity_id, info in exposed_entities["entities"].items()
if entity_id in matched_ids
]
else:
entities = list(exposed_entities["entities"].values())
prompt = [
"Live Context: An overview of the areas"
" and the devices in this smart home:",
yaml_util.dump(entities),
]
return {
"success": True,
"result": "\n".join(prompt),
}
class GetDateTimeTool(Tool):
"""Tool for getting the current date and time."""
name = "GetDateTime"
description = "Provides the current date and time."
async def async_call(
self,
hass: HomeAssistant,
tool_input: ToolInput,
llm_context: LLMContext,
) -> JsonObjectType:
"""Get the current date and time."""
now = dt_util.now()
return {
"success": True,
"result": {
"date": now.strftime("%Y-%m-%d"),
"time": now.strftime("%H:%M:%S"),
"timezone": now.strftime("%Z"),
"weekday": now.strftime("%A"),
},
}
+1
View File
@@ -2101,6 +2101,7 @@ NO_QUALITY_SCALE = [
"intent_script",
"intent",
"labs",
"llm",
"logbook",
"logger",
"lovelace",
@@ -843,9 +843,11 @@ async def test_redacted_thinking(
assert chat_log.content[1:] == snapshot
@patch("homeassistant.helpers.llm._async_get_registered_tools")
@patch("homeassistant.components.anthropic.entity.llm.AssistAPI._async_get_tools")
async def test_extended_thinking_tool_call(
mock_get_tools,
mock_registered_tools,
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_init_component,
@@ -875,6 +877,9 @@ async def test_extended_thinking_tool_call(
mock_tool.async_call.return_value = "Test response"
mock_get_tools.return_value = [mock_tool]
# The Assist API now also sources tools from the llm integration's registry;
# empty it so this test's tool set is just the mock tool.
mock_registered_tools.return_value = llm.LLMTools(tools=[])
mock_create_stream.return_value = [
(
@@ -28,6 +28,7 @@ from homeassistant.components.conversation.chat_log import (
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, llm
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from tests.common import async_fire_time_changed
@@ -168,6 +169,11 @@ async def test_dynamic_time_injection(
hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
"""Test that dynamic time injection works correctly."""
# The Assist API provides the GetDateTime tool via the llm integration, which
# suppresses the dynamic time prompt (case 3 below).
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "llm", {})
await hass.async_block_till_done()
class MyAPI(llm.API):
"""Test API."""
+1
View File
@@ -0,0 +1 @@
"""Tests for the LLM integration."""
+25
View File
@@ -0,0 +1,25 @@
"""Tests for the LLM integration."""
from unittest.mock import AsyncMock, Mock
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from tests.common import mock_platform
async def test_setup(hass: HomeAssistant) -> None:
"""Test the integration sets up."""
assert await async_setup_component(hass, "llm", {})
async def test_tool_platform_discovery(hass: HomeAssistant) -> None:
"""Test that an integration's llm tools platform is set up."""
platform = Mock(async_setup_tools=AsyncMock())
mock_platform(hass, "test.llm", platform)
hass.config.components.add("test")
assert await async_setup_component(hass, "llm", {})
await hass.async_block_till_done()
platform.async_setup_tools.assert_awaited_once_with(hass)
+566
View File
@@ -0,0 +1,566 @@
# serializer version: 1
# name: test_assist_api_snapshot[prompt]
list([
' 1. Recognize this requires live data.',
' 2. You MUST call `GetLiveContext`. This tool will provide the needed real-time information (like temperature from the local weather, lock status, etc.).',
' 3. Use the tool\'s response** to answer the user accurately (e.g., "The temperature outside is [value from tool].").',
' domain: light',
' domain: todo',
'"Is the fan on?", "What mode is the thermostat in?", "What is the temperature outside?"):',
'- names: Kitchen',
'- names: Shopping',
'For general knowledge questions not about the home: Answer truthfully from internal knowledge.',
'If the user asks about device existence/type (e.g., "Do I have lights in the bedroom?"): Answer',
'If the user asks about the CURRENT state, value, or mode (e.g., "Is the lock locked?",',
'Static Context: An overview of the areas and the devices in this smart home:',
'When controlling Home Assistant always call the intent tools. Use HassTurnOn to lock and HassTurnOff to unlock a lock. When controlling a device, prefer passing just name and domain. When controlling an area, prefer passing just area name and domain.',
'You ARE equipped to answer questions about the current state of',
"You are in area Test Area and all generic commands like 'turn on the lights' should target this area.",
'from the static context below.',
'functionality if the question requires live data.',
'the home using the `GetLiveContext` tool. This is a primary function. Do not state you lack the',
])
# ---
# name: test_assist_api_snapshot[tools]
dict({
'GetDateTime': dict({
'description': 'Provides the current date and time.',
'parameters': dict({
'properties': dict({
}),
'required': list([
]),
'type': 'object',
}),
}),
'GetLiveContext': dict({
'description': "Provides real-time information about the CURRENT state, value, or mode of devices, sensors, entities, or areas. Use this tool for: 1. Answering questions about current conditions (e.g., 'Is the light on?'). 2. As the first step in conditional actions (e.g., 'If the weather is rainy, turn off sprinklers' requires checking the weather first). You may filter for devices by name, domain, and area, including combining those filters. Prefer filtering by domain when searching for multiple devices of the same type.",
'parameters': dict({
'properties': dict({
'area': dict({
'description': 'Filter entities by area name or alias (case-insensitive).',
'type': 'string',
}),
'domain': dict({
'anyOf': list([
dict({
}),
dict({
'items': dict({
'type': 'string',
}),
'type': 'array',
}),
]),
'description': "Filter entities by domain (e.g. 'light', 'sensor'). Accepts a single domain or a list.",
}),
'name': dict({
'description': 'Filter entities by name or alias (case-insensitive).',
'type': 'string',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassCancelAllTimers': dict({
'description': 'Cancels all timers',
'parameters': dict({
'properties': dict({
'area': dict({
'type': 'string',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassCancelTimer': dict({
'description': 'Cancels a timer',
'parameters': dict({
'properties': dict({
'area': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
'start_hours': dict({
'minimum': 0,
'type': 'integer',
}),
'start_minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'start_seconds': dict({
'minimum': 0,
'type': 'integer',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassDecreaseTimer': dict({
'description': 'Removes time from a timer',
'parameters': dict({
'anyOf': list([
dict({
'required': list([
'hours',
]),
}),
dict({
'required': list([
'minutes',
]),
}),
dict({
'required': list([
'seconds',
]),
}),
]),
'properties': dict({
'area': dict({
'type': 'string',
}),
'hours': dict({
'minimum': 0,
'type': 'integer',
}),
'minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'name': dict({
'type': 'string',
}),
'seconds': dict({
'minimum': 0,
'type': 'integer',
}),
'start_hours': dict({
'minimum': 0,
'type': 'integer',
}),
'start_minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'start_seconds': dict({
'minimum': 0,
'type': 'integer',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassIncreaseTimer': dict({
'description': 'Adds more time to a timer',
'parameters': dict({
'properties': dict({
'area': dict({
'type': 'string',
}),
'hours': dict({
'minimum': 0,
'type': 'integer',
}),
'minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'name': dict({
'type': 'string',
}),
'seconds': dict({
'minimum': 0,
'type': 'integer',
}),
'start_hours': dict({
'minimum': 0,
'type': 'integer',
}),
'start_minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'start_seconds': dict({
'minimum': 0,
'type': 'integer',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassListAddItem': dict({
'description': 'Add item to a todo list',
'parameters': dict({
'properties': dict({
'item': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
}),
'required': list([
'item',
'name',
]),
'type': 'object',
}),
}),
'HassListCompleteItem': dict({
'description': 'Complete item on a todo list',
'parameters': dict({
'properties': dict({
'item': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
}),
'required': list([
'item',
'name',
]),
'type': 'object',
}),
}),
'HassListRemoveItem': dict({
'description': 'Remove one or more items from a todo list',
'parameters': dict({
'properties': dict({
'item': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
}),
'required': list([
'item',
'name',
]),
'type': 'object',
}),
}),
'HassPauseTimer': dict({
'description': 'Pauses a running timer',
'parameters': dict({
'properties': dict({
'area': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
'start_hours': dict({
'minimum': 0,
'type': 'integer',
}),
'start_minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'start_seconds': dict({
'minimum': 0,
'type': 'integer',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassStartTimer': dict({
'description': 'Starts a new timer',
'parameters': dict({
'anyOf': list([
dict({
'required': list([
'hours',
]),
}),
dict({
'required': list([
'minutes',
]),
}),
dict({
'required': list([
'seconds',
]),
}),
]),
'properties': dict({
'conversation_command': dict({
'type': 'string',
}),
'hours': dict({
'minimum': 0,
'type': 'integer',
}),
'minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'name': dict({
'type': 'string',
}),
'seconds': dict({
'minimum': 0,
'type': 'integer',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassTimerStatus': dict({
'description': 'Reports the current status of timers',
'parameters': dict({
'properties': dict({
'area': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
'start_hours': dict({
'minimum': 0,
'type': 'integer',
}),
'start_minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'start_seconds': dict({
'minimum': 0,
'type': 'integer',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassTurnOff': dict({
'description': "Turns off/closes a device or entity. For locks, this performs an 'unlock' action. Use for requests like 'turn off', 'deactivate', 'disable', or 'unlock'.",
'parameters': dict({
'properties': dict({
'area': dict({
'type': 'string',
}),
'device_class': dict({
'items': dict({
'enum': list([
'awning',
'blind',
'curtain',
'damper',
'door',
'garage',
'gas',
'gate',
'identify',
'outlet',
'projector',
'receiver',
'restart',
'shade',
'shutter',
'speaker',
'switch',
'tv',
'update',
'water',
'window',
]),
'type': 'string',
}),
'type': 'array',
}),
'domain': dict({
'items': dict({
'type': 'string',
}),
'type': 'array',
}),
'floor': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassTurnOn': dict({
'description': "Turns on/opens/presses a device or entity. For locks, this performs a 'lock' action. Use for requests like 'turn on', 'activate', 'enable', or 'lock'.",
'parameters': dict({
'properties': dict({
'area': dict({
'type': 'string',
}),
'device_class': dict({
'items': dict({
'enum': list([
'awning',
'blind',
'curtain',
'damper',
'door',
'garage',
'gas',
'gate',
'identify',
'outlet',
'projector',
'receiver',
'restart',
'shade',
'shutter',
'speaker',
'switch',
'tv',
'update',
'water',
'window',
]),
'type': 'string',
}),
'type': 'array',
}),
'domain': dict({
'items': dict({
'type': 'string',
}),
'type': 'array',
}),
'floor': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'HassUnpauseTimer': dict({
'description': 'Resumes a paused timer',
'parameters': dict({
'properties': dict({
'area': dict({
'type': 'string',
}),
'name': dict({
'type': 'string',
}),
'start_hours': dict({
'minimum': 0,
'type': 'integer',
}),
'start_minutes': dict({
'minimum': 0,
'type': 'integer',
}),
'start_seconds': dict({
'minimum': 0,
'type': 'integer',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'calendar_get_events': dict({
'description': "Get events from a calendar. When asked if something happens, search the whole week. Results are RFC 5545 which means 'end' is exclusive.",
'parameters': dict({
'properties': dict({
'calendar': dict({
'enum': list([
'Personal',
]),
'type': 'string',
}),
'range': dict({
'enum': list([
'today',
'week',
]),
'type': 'string',
}),
}),
'required': list([
'calendar',
'range',
]),
'type': 'object',
}),
}),
'test_script': dict({
'description': "This is a test script. Aliases: ['test_script']",
'parameters': dict({
'properties': dict({
'beer': dict({
'description': 'Number of beers',
'type': 'string',
}),
'wine': dict({
'type': 'string',
}),
}),
'required': list([
]),
'type': 'object',
}),
}),
'todo_get_items': dict({
'description': "Query a to-do list to find out what items are on it. Use this to answer questions like 'What's on my task list?' or 'Read my grocery list'. Filters items by status (needs_action, completed, all).",
'parameters': dict({
'properties': dict({
'status': dict({
'default': 'needs_action',
'description': 'Filter returned items by status, by default returns incomplete items',
'enum': list([
'all',
'completed',
'needs_action',
]),
'type': 'string',
}),
'todo_list': dict({
'enum': list([
'Shopping',
]),
'type': 'string',
}),
}),
'required': list([
'todo_list',
]),
'type': 'object',
}),
}),
})
# ---
+260 -25
View File
@@ -2,16 +2,19 @@
from datetime import timedelta
from decimal import Decimal
from typing import Any
from unittest.mock import patch
import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import calendar, todo
from homeassistant.components import calendar, script, todo
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.components.intent import async_register_timer_handler
from homeassistant.components.script import ScriptConfig
from homeassistant.core import Context, HomeAssistant, State, SupportsResponse
from homeassistant.core import Context, HomeAssistant, State, SupportsResponse, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import (
area_registry as ar,
@@ -53,6 +56,24 @@ class MyAPI(llm.API):
return llm.APIInstance(self, self.prompt, llm_context, self.tools)
class _StubTool(llm.Tool):
"""Minimal tool for registry tests."""
def __init__(self, name: str) -> None:
"""Initialize the stub tool."""
self.name = name
self.description = f"{name} description"
async def async_call(
self,
hass: HomeAssistant,
tool_input: llm.ToolInput,
llm_context: llm.LLMContext,
) -> JsonObjectType:
"""Return an empty result."""
return {}
async def test_get_api_no_existing(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
@@ -128,6 +149,72 @@ async def test_multiple_apis(hass: HomeAssistant, llm_context: llm.LLMContext) -
assert await llm.async_get_api(hass, "test-2", llm_context)
async def test_register_tool_provider(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test registering and unregistering a tool provider."""
tool = _StubTool("my_tool")
@callback
def provider(_hass: HomeAssistant, _llm_context: llm.LLMContext) -> llm.LLMTools:
return llm.LLMTools(tools=[tool], prompt="use my_tool wisely")
unreg = llm.async_register_tool_provider(hass, provider, apis={"assist": None})
result = llm._async_get_registered_tools(hass, "assist", llm_context)
assert result.tools == [tool]
assert result.prompt == "use my_tool wisely"
unreg()
result = llm._async_get_registered_tools(hass, "assist", llm_context)
assert result.tools == []
assert result.prompt is None
async def test_register_tool_static(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test the static single-tool registration convenience."""
tool = _StubTool("static_tool")
unreg = llm.async_register_tool(hass, tool, apis={"assist": None})
result = llm._async_get_registered_tools(hass, "assist", llm_context)
assert result.tools == [tool]
assert result.prompt is None
unreg()
assert llm._async_get_registered_tools(hass, "assist", llm_context).tools == []
async def test_register_tool_provider_multiple_apis(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test a provider in multiple APIs and prompt merging across providers."""
tool_a = _StubTool("tool_a")
tool_b = _StubTool("tool_b")
@callback
def provider_a(_hass: HomeAssistant, _llm_context: llm.LLMContext) -> llm.LLMTools:
return llm.LLMTools(tools=[tool_a], prompt="prompt a")
@callback
def provider_b(_hass: HomeAssistant, _llm_context: llm.LLMContext) -> llm.LLMTools:
return llm.LLMTools(tools=[tool_b], prompt="prompt b")
llm.async_register_tool_provider(
hass, provider_a, apis={"assist": None, "other": None}
)
llm.async_register_tool_provider(hass, provider_b, apis={"assist": None})
assist = llm._async_get_registered_tools(hass, "assist", llm_context)
assert assist.tools == [tool_a, tool_b]
assert assist.prompt == "prompt a\nprompt b"
other = llm._async_get_registered_tools(hass, "other", llm_context)
assert other.tools == [tool_a]
assert other.prompt == "prompt a"
async def test_call_tool_no_existing(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
@@ -148,6 +235,7 @@ async def test_assist_api(
) -> None:
"""Test Assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "llm", {})
entity_registry.async_get_or_create(
"light",
@@ -327,12 +415,121 @@ async def test_assist_api(
}
def _normalize_schema(value: Any) -> Any:
"""Recursively sort scalar lists (e.g. enum values) for a stable snapshot.
Some tool parameter schemas build enum options from Python sets, so their
order varies per process. Order is semantically irrelevant here.
"""
if isinstance(value, dict):
return {key: _normalize_schema(val) for key, val in value.items()}
if isinstance(value, list):
items = [_normalize_schema(item) for item in value]
if all(isinstance(item, (str, int, float, bool)) for item in items):
return sorted(items, key=repr)
return items
return value
async def test_assist_api_snapshot(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
device_registry: dr.DeviceRegistry,
snapshot: SnapshotAssertion,
) -> None:
"""Golden snapshot of the Assist API prompt + tools.
Behavior-parity net for the v1 tool-platform refactor: the assembled prompt
and the full serialized tool set (name, description, parameters) must stay
identical as built-in tools and intents move out of AssistAPI into per-
integration platforms.
"""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(hass, "llm", {})
assert await async_setup_component(hass, "calendar", {})
assert await async_setup_component(hass, "todo", {})
assert await async_setup_component(
hass,
"script",
{
"script": {
"test_script": {
"description": "This is a test script",
"sequence": [],
"fields": {
"beer": {"description": "Number of beers"},
"wine": {},
},
}
}
},
)
entry = MockConfigEntry(title=None)
entry.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=entry.entry_id,
connections={("test", "1234")},
suggested_area="Test Area",
)
# Expose one entity per tool-bearing domain so every built-in tool appears.
for domain, object_id, name in (
("light", "kitchen", "Kitchen"),
("calendar", "personal", "Personal"),
("todo", "shopping", "Shopping"),
):
created = entity_registry.async_get_or_create(
domain,
"test",
f"mock-{object_id}",
original_name=name,
suggested_object_id=object_id,
)
hass.states.async_set(created.entity_id, "on", {"friendly_name": name})
async_expose_entity(hass, "conversation", created.entity_id, True)
async_expose_entity(hass, "conversation", "script.test_script", True)
async_register_timer_handler(hass, device.id, lambda *args: None)
llm_context = llm.LLMContext(
platform="test_platform",
context=Context(),
language="*",
assistant="conversation",
device_id=device.id,
)
api = await llm.async_get_api(hass, "assist", llm_context)
# Order-independent: the v1 refactor moves tools/prompt fragments out of
# AssistAPI into per-integration registrations, which changes their order but
# not their content. Compare tools as a name-keyed mapping and the prompt as a
# set of lines so each migration step verifies same-tools/same-content/
# same-prompt regardless of ordering. (Final ordering vs dev is checked
# separately at the end.)
assert sorted(
line for line in api.api_prompt.splitlines() if line.strip()
) == snapshot(name="prompt")
assert {
tool.name: {
"description": tool.description,
"parameters": _normalize_schema(
convert(tool.parameters, custom_serializer=api.custom_serializer)
),
}
for tool in api.tools
} == snapshot(name="tools")
async def test_assist_api_get_timer_tools(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test getting timer tools with Assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(hass, "llm", {})
await hass.async_block_till_done()
api = await llm.async_get_api(hass, "assist", llm_context)
assert "HassStartTimer" not in [tool.name for tool in api.tools]
@@ -351,6 +548,7 @@ async def test_assist_api_tools(
"""Test getting timer tools with Assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(hass, "llm", {})
llm_context.device_id = "test_device"
@@ -381,10 +579,32 @@ async def test_assist_api_tools(
]
async def test_assist_api_registered_tools(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test Assist API includes tools and prompt from the tool registry."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
tool = _StubTool("registered_tool")
@callback
def provider(_hass: HomeAssistant, _llm_context: llm.LLMContext) -> llm.LLMTools:
return llm.LLMTools(tools=[tool], prompt="registered prompt fragment")
llm.async_register_tool_provider(hass, provider, apis={"assist": None})
api = await llm.async_get_api(hass, "assist", llm_context)
assert "registered_tool" in [t.name for t in api.tools]
assert "registered prompt fragment" in api.api_prompt
async def test_assist_api_description(
hass: HomeAssistant, llm_context: llm.LLMContext
) -> None:
"""Test intent description with Assist API."""
assert await async_setup_component(hass, "llm", {})
class MyIntentHandler(intent.IntentHandler):
intent_type = "test_intent"
@@ -410,6 +630,7 @@ async def test_assist_api_prompt(
"""Test prompt for the assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(hass, "llm", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
@@ -418,6 +639,7 @@ async def test_assist_api_prompt(
assistant="conversation",
device_id=None,
)
await hass.async_block_till_done()
api = await llm.async_get_api(hass, "assist", llm_context)
assert api.api_prompt == (
@@ -762,11 +984,11 @@ Static Context: An overview of the areas and the devices in this smart home:
)
api = await llm.async_get_api(hass, "assist", llm_context)
assert api.api_prompt == (
f"""{first_part_prompt}
{dynamic_context_prompt}
{stateless_exposed_entities_prompt}
f"""{stateless_exposed_entities_prompt}
{area_prompt}
{no_timer_prompt}"""
{first_part_prompt}
{no_timer_prompt}
{dynamic_context_prompt}"""
)
# Verify that the GetLiveContext tool returns the same results
@@ -787,11 +1009,11 @@ Static Context: An overview of the areas and the devices in this smart home:
)
api = await llm.async_get_api(hass, "assist", llm_context)
assert api.api_prompt == (
f"""{first_part_prompt}
{dynamic_context_prompt}
{stateless_exposed_entities_prompt}
f"""{stateless_exposed_entities_prompt}
{area_prompt}
{no_timer_prompt}"""
{first_part_prompt}
{no_timer_prompt}
{dynamic_context_prompt}"""
)
# Add floor
@@ -804,11 +1026,11 @@ Static Context: An overview of the areas and the devices in this smart home:
)
api = await llm.async_get_api(hass, "assist", llm_context)
assert api.api_prompt == (
f"""{first_part_prompt}
{dynamic_context_prompt}
{stateless_exposed_entities_prompt}
f"""{stateless_exposed_entities_prompt}
{area_prompt}
{no_timer_prompt}"""
{first_part_prompt}
{no_timer_prompt}
{dynamic_context_prompt}"""
)
# Register device for timers
@@ -817,10 +1039,10 @@ Static Context: An overview of the areas and the devices in this smart home:
api = await llm.async_get_api(hass, "assist", llm_context)
# The no_timer_prompt is gone
assert api.api_prompt == (
f"""{first_part_prompt}
{dynamic_context_prompt}
{stateless_exposed_entities_prompt}
{area_prompt}"""
f"""{stateless_exposed_entities_prompt}
{area_prompt}
{first_part_prompt}
{dynamic_context_prompt}"""
)
@@ -833,6 +1055,7 @@ async def test_get_live_context_tool_filter(
"""Test the filter parameters of the GetLiveContext tool."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(hass, "llm", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
@@ -929,6 +1152,7 @@ async def test_get_live_context_tool_filter(
hass.states.async_set(office_ac.entity_id, "cool")
hass.states.async_set(kitchen_ac.entity_id, "heat")
await hass.async_block_till_done()
api = await llm.async_get_api(hass, "assist", llm_context)
# Filter by area and domain (example 1)
@@ -1168,6 +1392,7 @@ async def test_script_tool(
"""Test ScriptTool for the assist API."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "intent", {})
assert await async_setup_component(hass, "llm", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
@@ -1223,11 +1448,11 @@ async def test_script_tool(
area = area_registry.async_create("Living room")
floor = floor_registry.async_create("2")
assert llm.ACTION_PARAMETERS_CACHE not in hass.data
assert script.llm.ACTION_PARAMETERS_CACHE not in hass.data
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
tools = [tool for tool in api.tools if isinstance(tool, script.llm.ScriptTool)]
assert len(tools) == 2
tool = tools[0]
@@ -1247,7 +1472,7 @@ async def test_script_tool(
}
assert tool.parameters.schema == schema
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
assert hass.data[script.llm.ACTION_PARAMETERS_CACHE]["script"] == {
"test_script": (
"This is a test script. Aliases: ['script alias', 'script name']",
vol.Schema(schema),
@@ -1347,11 +1572,11 @@ async def test_script_tool(
):
await hass.services.async_call("script", "reload", blocking=True)
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {}
assert hass.data[script.llm.ACTION_PARAMETERS_CACHE]["script"] == {}
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
tools = [tool for tool in api.tools if isinstance(tool, script.llm.ScriptTool)]
assert len(tools) == 2
tool = tools[0]
@@ -1363,7 +1588,7 @@ async def test_script_tool(
schema = {vol.Required("beer", description="Number of beers"): cv.string}
assert tool.parameters.schema == schema
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
assert hass.data[script.llm.ACTION_PARAMETERS_CACHE]["script"] == {
"test_script": (
"This is a new test script. Aliases: ['script alias', 'script name']",
vol.Schema(schema),
@@ -1378,6 +1603,7 @@ async def test_script_tool(
async def test_script_tool_name(hass: HomeAssistant) -> None:
"""Test that script tool name is not started with a digit."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "llm", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
@@ -1407,7 +1633,7 @@ async def test_script_tool_name(hass: HomeAssistant) -> None:
api = await llm.async_get_api(hass, "assist", llm_context)
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
tools = [tool for tool in api.tools if isinstance(tool, script.llm.ScriptTool)]
assert len(tools) == 1
tool = tools[0]
@@ -1692,6 +1918,8 @@ async def test_selector_serializer(
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
"""Test the calendar get events tool."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "llm", {})
assert await async_setup_component(hass, "calendar", {})
hass.states.async_set(
"calendar.test_calendar", "on", {"friendly_name": "Mock Calendar Name"}
)
@@ -1704,6 +1932,8 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
assistant="conversation",
device_id=None,
)
# Wait for the llm integration to discover the calendar tools platform.
await hass.async_block_till_done()
api = await llm.async_get_api(hass, "assist", llm_context)
tool = next(
(tool for tool in api.tools if tool.name == "calendar_get_events"), None
@@ -1793,6 +2023,7 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
async def test_todo_get_items_tool(hass: HomeAssistant) -> None:
"""Test the todo get items tool."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "llm", {})
assert await async_setup_component(hass, "todo", {})
hass.states.async_set(
"todo.test_list", "0", {"friendly_name": "Mock Todo List Name"}
@@ -1806,6 +2037,8 @@ async def test_todo_get_items_tool(hass: HomeAssistant) -> None:
assistant="conversation",
device_id=None,
)
# Wait for the llm integration to discover the todo tools platform.
await hass.async_block_till_done()
api = await llm.async_get_api(hass, "assist", llm_context)
tool = next((tool for tool in api.tools if tool.name == "todo_get_items"), None)
assert tool is not None
@@ -1905,6 +2138,7 @@ async def test_get_date_time_tool(hass: HomeAssistant) -> None:
"""Test the GetDateTime tool."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "llm", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",
@@ -1939,6 +2173,7 @@ async def test_get_date_time_tool(hass: HomeAssistant) -> None:
async def test_no_tools_exposed(hass: HomeAssistant) -> None:
"""Test that tools are not exposed when no entities are exposed."""
assert await async_setup_component(hass, "homeassistant", {})
assert await async_setup_component(hass, "llm", {})
context = Context()
llm_context = llm.LLMContext(
platform="test_platform",