mirror of
https://github.com/home-assistant/core.git
synced 2026-06-18 09:52:57 +02:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 77fe61db3e | |||
| ac3a1493aa | |||
| fbda9f4fc3 | |||
| 5802cada22 | |||
| f357dd3d49 | |||
| 1eca57a01b | |||
| 0dde553653 | |||
| 5943b70577 | |||
| 99ee92029d | |||
| 717acdf771 | |||
| e76d83241e | |||
| 6024f96d56 | |||
| 5eec6167cc | |||
| 184025cfb4 |
Generated
+2
@@ -1022,6 +1022,8 @@ CLAUDE.md @home-assistant/core
|
||||
/tests/components/litterrobot/ @natekspencer @tkdrob
|
||||
/homeassistant/components/livisi/ @StefanIacobLivisi @planbnet
|
||||
/tests/components/livisi/ @StefanIacobLivisi @planbnet
|
||||
/homeassistant/components/llm/ @home-assistant/core
|
||||
/tests/components/llm/ @home-assistant/core
|
||||
/homeassistant/components/local_calendar/ @allenporter
|
||||
/tests/components/local_calendar/ @allenporter
|
||||
/homeassistant/components/local_ip/ @issacg
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"""LLM tools for the calendar integration."""
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import intent, llm
|
||||
from homeassistant.util import dt as dt_util
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from . import SERVICE_GET_EVENTS
|
||||
from .const import DOMAIN
|
||||
|
||||
|
||||
async def async_setup_tools(hass: HomeAssistant) -> None:
|
||||
"""Set up the calendar LLM tools."""
|
||||
llm.async_register_tool_provider(
|
||||
hass, _calendar_tools, apis={llm.LLM_API_ASSIST: None}
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def _calendar_tools(hass: HomeAssistant, llm_context: llm.LLMContext) -> llm.LLMTools:
|
||||
"""Return the calendar tools for the exposed calendars."""
|
||||
if llm_context.assistant is None:
|
||||
return llm.LLMTools(tools=[])
|
||||
|
||||
exposed = llm.async_get_exposed_entities(
|
||||
hass, llm_context.assistant, include_state=False
|
||||
)
|
||||
if not exposed[DOMAIN]:
|
||||
return llm.LLMTools(tools=[])
|
||||
|
||||
names = []
|
||||
for info in exposed[DOMAIN].values():
|
||||
names.extend(info["names"].split(", "))
|
||||
return llm.LLMTools(tools=[CalendarGetEventsTool(names)])
|
||||
|
||||
|
||||
class CalendarGetEventsTool(llm.Tool):
|
||||
"""LLM Tool allowing querying a calendar."""
|
||||
|
||||
name = "calendar_get_events"
|
||||
description = (
|
||||
"Get events from a calendar. "
|
||||
"When asked if something happens, search the whole week. "
|
||||
"Results are RFC 5545 which means 'end' is exclusive."
|
||||
)
|
||||
|
||||
def __init__(self, calendars: list[str]) -> None:
|
||||
"""Init the get events tool."""
|
||||
self.parameters = vol.Schema(
|
||||
{
|
||||
vol.Required("calendar"): vol.In(calendars),
|
||||
vol.Required("range"): vol.In(["today", "week"]),
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: llm.ToolInput,
|
||||
llm_context: llm.LLMContext,
|
||||
) -> JsonObjectType:
|
||||
"""Query a calendar."""
|
||||
data = self.parameters(tool_input.tool_args)
|
||||
result = intent.async_match_targets(
|
||||
hass,
|
||||
intent.MatchTargetsConstraints(
|
||||
name=data["calendar"],
|
||||
domains=[DOMAIN],
|
||||
assistant=llm_context.assistant,
|
||||
),
|
||||
)
|
||||
if not result.is_match:
|
||||
return {"success": False, "error": "Calendar not found"}
|
||||
|
||||
entity_id = result.states[0].entity_id
|
||||
if data["range"] == "today":
|
||||
start = dt_util.now()
|
||||
end = dt_util.start_of_local_day() + timedelta(days=1)
|
||||
elif data["range"] == "week":
|
||||
start = dt_util.now()
|
||||
end = dt_util.start_of_local_day() + timedelta(days=7)
|
||||
|
||||
service_data = {
|
||||
"entity_id": entity_id,
|
||||
"start_date_time": start.isoformat(),
|
||||
"end_date_time": end.isoformat(),
|
||||
}
|
||||
|
||||
service_result = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
SERVICE_GET_EVENTS,
|
||||
service_data,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
events = [
|
||||
event if "T" in event["start"] else {**event, "all_day": True}
|
||||
for event in cast(dict, service_result)[entity_id]["events"]
|
||||
]
|
||||
|
||||
return {"success": True, "result": events}
|
||||
@@ -2,7 +2,7 @@
|
||||
"domain": "conversation",
|
||||
"name": "Conversation",
|
||||
"codeowners": ["@home-assistant/core", "@synesthesiam", "@arturpragacz"],
|
||||
"dependencies": ["http", "intent"],
|
||||
"dependencies": ["http", "intent", "llm"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/conversation",
|
||||
"integration_type": "entity",
|
||||
"quality_scale": "internal",
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""The LLM integration.
|
||||
|
||||
Owns the LLM tools platform: integrations contribute tools to the LLM APIs
|
||||
through an ``<integration>/llm.py`` platform with an ``async_setup_tools`` hook,
|
||||
discovered here. The framework (registry, ``Tool``, the APIs) lives in
|
||||
``homeassistant.helpers.llm``; this integration owns the lifecycle, mirroring the
|
||||
``intent`` helper/integration split.
|
||||
"""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, llm
|
||||
from homeassistant.helpers.integration_platform import (
|
||||
async_process_integration_platforms,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
from .intents import intent_tools
|
||||
from .tools import DYNAMIC_CONTEXT_PROMPT, GetDateTimeTool, GetLiveContextTool
|
||||
|
||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
|
||||
|
||||
class LLMToolsPlatformProtocol(Protocol):
|
||||
"""Define the format that LLM tools platforms can have."""
|
||||
|
||||
async def async_setup_tools(self, hass: HomeAssistant) -> None:
|
||||
"""Set up the integration's LLM tools."""
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the LLM integration."""
|
||||
llm.async_register_tool_provider(
|
||||
hass, intent_tools, apis={llm.LLM_API_ASSIST: None}
|
||||
)
|
||||
llm.async_register_tool(hass, GetDateTimeTool(), apis={llm.LLM_API_ASSIST: None})
|
||||
llm.async_register_tool_provider(
|
||||
hass, _live_context_tools, apis={llm.LLM_API_ASSIST: None}
|
||||
)
|
||||
await async_process_integration_platforms(
|
||||
hass, DOMAIN, _async_process_llm_tools_platform
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@callback
|
||||
def _live_context_tools(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> llm.LLMTools:
|
||||
"""Return the live-context tool and its prompt when entities are exposed."""
|
||||
if llm_context.assistant is None:
|
||||
return llm.LLMTools(tools=[])
|
||||
|
||||
exposed = llm.async_get_exposed_entities(
|
||||
hass, llm_context.assistant, include_state=False
|
||||
)
|
||||
exposed_domains = {info["domain"] for info in exposed["entities"].values()}
|
||||
if not exposed_domains:
|
||||
return llm.LLMTools(tools=[])
|
||||
|
||||
return llm.LLMTools(tools=[GetLiveContextTool()], prompt=DYNAMIC_CONTEXT_PROMPT)
|
||||
|
||||
|
||||
async def _async_process_llm_tools_platform(
|
||||
hass: HomeAssistant, domain: str, platform: LLMToolsPlatformProtocol
|
||||
) -> None:
|
||||
"""Register the LLM tools of an integration."""
|
||||
await platform.async_setup_tools(hass)
|
||||
@@ -0,0 +1,3 @@
|
||||
"""Constants for the LLM integration."""
|
||||
|
||||
DOMAIN = "llm"
|
||||
@@ -0,0 +1,88 @@
|
||||
"""LLM tools wrapping exposed intents."""
|
||||
|
||||
from functools import cache, partial
|
||||
|
||||
import slugify as unicode_slug
|
||||
|
||||
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||
from homeassistant.components.intent import async_device_supports_timers
|
||||
from homeassistant.components.weather import INTENT_GET_WEATHER
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import intent, llm
|
||||
|
||||
DEVICE_CONTROL_TOOL_USAGE_PROMPT = (
|
||||
"When controlling Home Assistant always call the intent tools. "
|
||||
"Use HassTurnOn to lock and HassTurnOff to unlock a lock. "
|
||||
"When controlling a device, prefer passing just name and domain. "
|
||||
"When controlling an area, prefer passing just area name and domain."
|
||||
)
|
||||
|
||||
IGNORE_INTENTS = {
|
||||
intent.INTENT_GET_TEMPERATURE,
|
||||
INTENT_GET_WEATHER,
|
||||
INTENT_OPEN_COVER, # deprecated
|
||||
INTENT_CLOSE_COVER, # deprecated
|
||||
intent.INTENT_GET_STATE,
|
||||
intent.INTENT_NEVERMIND,
|
||||
intent.INTENT_TOGGLE,
|
||||
intent.INTENT_GET_CURRENT_DATE,
|
||||
intent.INTENT_GET_CURRENT_TIME,
|
||||
intent.INTENT_RESPOND,
|
||||
}
|
||||
|
||||
TIMER_INTENTS = {
|
||||
intent.INTENT_START_TIMER,
|
||||
intent.INTENT_CANCEL_TIMER,
|
||||
intent.INTENT_INCREASE_TIMER,
|
||||
intent.INTENT_DECREASE_TIMER,
|
||||
intent.INTENT_PAUSE_TIMER,
|
||||
intent.INTENT_UNPAUSE_TIMER,
|
||||
intent.INTENT_TIMER_STATUS,
|
||||
}
|
||||
|
||||
_slugify = cache(partial(unicode_slug.slugify, separator="_", lowercase=False))
|
||||
|
||||
|
||||
@callback
|
||||
def intent_tools(hass: HomeAssistant, llm_context: llm.LLMContext) -> llm.LLMTools:
|
||||
"""Return the intent tools and their prompt for the exposed entities."""
|
||||
exposed = (
|
||||
llm.async_get_exposed_entities(hass, llm_context.assistant, include_state=False)
|
||||
if llm_context.assistant
|
||||
else None
|
||||
)
|
||||
|
||||
ignore = IGNORE_INTENTS
|
||||
if not llm_context.device_id or not async_device_supports_timers(
|
||||
hass, llm_context.device_id
|
||||
):
|
||||
ignore = ignore | TIMER_INTENTS
|
||||
|
||||
handlers = [
|
||||
handler
|
||||
for handler in intent.async_get(hass)
|
||||
if handler.intent_type not in ignore
|
||||
]
|
||||
|
||||
if exposed is not None:
|
||||
exposed_domains = {info["domain"] for info in exposed["entities"].values()}
|
||||
handlers = [
|
||||
handler
|
||||
for handler in handlers
|
||||
if handler.platforms is None or handler.platforms & exposed_domains
|
||||
]
|
||||
|
||||
tools: list[llm.Tool] = [
|
||||
llm.IntentTool(_slugify(handler.intent_type), handler) for handler in handlers
|
||||
]
|
||||
|
||||
prompt = None
|
||||
if exposed and exposed["entities"]:
|
||||
parts = [DEVICE_CONTROL_TOOL_USAGE_PROMPT]
|
||||
if not llm_context.device_id or not async_device_supports_timers(
|
||||
hass, llm_context.device_id
|
||||
):
|
||||
parts.append("This device is not able to start timers.")
|
||||
prompt = "\n".join(parts)
|
||||
|
||||
return llm.LLMTools(tools=tools, prompt=prompt)
|
||||
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"domain": "llm",
|
||||
"name": "LLM",
|
||||
"after_dependencies": ["intent"],
|
||||
"codeowners": ["@home-assistant/core"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/llm",
|
||||
"integration_type": "system",
|
||||
"iot_class": "calculated",
|
||||
"quality_scale": "internal"
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
"""Tools for the LLM integration."""
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv, intent
|
||||
from homeassistant.helpers.llm import (
|
||||
NO_ENTITIES_PROMPT,
|
||||
LLMContext,
|
||||
Tool,
|
||||
ToolInput,
|
||||
async_get_exposed_entities,
|
||||
)
|
||||
from homeassistant.util import dt as dt_util, yaml as yaml_util
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
DYNAMIC_CONTEXT_PROMPT = (
|
||||
"You ARE equipped to answer questions about the"
|
||||
" current state of\n"
|
||||
"the home using the `GetLiveContext` tool."
|
||||
" This is a primary function."
|
||||
" Do not state you lack the\n"
|
||||
"functionality if the question requires live data.\n"
|
||||
"If the user asks about device existence/type"
|
||||
' (e.g., "Do I have lights in the bedroom?"):'
|
||||
" Answer\n"
|
||||
"from the static context below.\n"
|
||||
"If the user asks about the CURRENT state, value,"
|
||||
' or mode (e.g., "Is the lock locked?",\n'
|
||||
'"Is the fan on?",'
|
||||
' "What mode is the thermostat in?",'
|
||||
' "What is the temperature outside?"):\n'
|
||||
" 1. Recognize this requires live data.\n"
|
||||
" 2. You MUST call `GetLiveContext`."
|
||||
" This tool will provide the needed real-time"
|
||||
" information (like temperature from the local"
|
||||
" weather, lock status, etc.).\n"
|
||||
" 3. Use the tool's response** to answer the"
|
||||
" user accurately"
|
||||
' (e.g., "The temperature outside is'
|
||||
' [value from tool].").\n'
|
||||
"For general knowledge questions not about the"
|
||||
" home: Answer truthfully from internal"
|
||||
" knowledge.\n"
|
||||
)
|
||||
|
||||
|
||||
class GetDateTimeTool(Tool):
|
||||
"""Tool for getting the current date and time."""
|
||||
|
||||
name = "GetDateTime"
|
||||
description = "Provides the current date and time."
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: ToolInput,
|
||||
llm_context: LLMContext,
|
||||
) -> JsonObjectType:
|
||||
"""Get the current date and time."""
|
||||
now = dt_util.now()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": {
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"time": now.strftime("%H:%M:%S"),
|
||||
"timezone": now.strftime("%Z"),
|
||||
"weekday": now.strftime("%A"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _live_context_match_error(
|
||||
match_result: intent.MatchTargetsResult,
|
||||
name_filter: str | None,
|
||||
area_filter: str | None,
|
||||
domain_filter: list[str] | None,
|
||||
) -> str:
|
||||
"""Build an actionable error message for a failed GetLiveContext match."""
|
||||
reason = match_result.no_match_reason
|
||||
if reason is intent.MatchFailedReason.INVALID_AREA:
|
||||
return f"Area '{match_result.no_match_name}' does not exist"
|
||||
if reason is intent.MatchFailedReason.NAME:
|
||||
return f"No exposed entities matched name '{name_filter}'"
|
||||
if reason is intent.MatchFailedReason.AREA:
|
||||
return f"No exposed entities found in area '{area_filter}'"
|
||||
if reason is intent.MatchFailedReason.DOMAIN:
|
||||
domains = ", ".join(domain_filter) if domain_filter else ""
|
||||
return f"No exposed entities found in domain(s): {domains}"
|
||||
return "No entities matched the provided filter"
|
||||
|
||||
|
||||
class GetLiveContextTool(Tool):
|
||||
"""Tool for getting the current state of exposed entities.
|
||||
|
||||
This returns state for all entities that have been exposed to
|
||||
the assistant. This is different than the GetState intent, which
|
||||
returns state for entities based on intent parameters.
|
||||
"""
|
||||
|
||||
name = "GetLiveContext"
|
||||
description = (
|
||||
"Provides real-time information about the"
|
||||
" CURRENT state, value, or mode of devices,"
|
||||
" sensors, entities, or areas. "
|
||||
"Use this tool for: "
|
||||
"1. Answering questions about current"
|
||||
" conditions (e.g., 'Is the light on?'). "
|
||||
"2. As the first step in conditional actions"
|
||||
" (e.g., 'If the weather is rainy, turn off"
|
||||
" sprinklers' requires checking the weather"
|
||||
" first). "
|
||||
"You may filter for devices by name, domain,"
|
||||
" and area, including combining those"
|
||||
" filters. "
|
||||
"Prefer filtering by domain when searching"
|
||||
" for multiple devices of the same type."
|
||||
)
|
||||
parameters = vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
"name",
|
||||
description="Filter entities by name or alias (case-insensitive).",
|
||||
): cv.string,
|
||||
vol.Optional(
|
||||
"domain",
|
||||
description=(
|
||||
"Filter entities by domain"
|
||||
" (e.g. 'light', 'sensor')."
|
||||
" Accepts a single domain or a list."
|
||||
),
|
||||
): vol.Any(cv.string, [cv.string]),
|
||||
vol.Optional(
|
||||
"area",
|
||||
description="Filter entities by area name or alias (case-insensitive).",
|
||||
): cv.string,
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: ToolInput,
|
||||
llm_context: LLMContext,
|
||||
) -> JsonObjectType:
|
||||
"""Get the current state of exposed entities."""
|
||||
if llm_context.assistant is None:
|
||||
# Note this doesn't happen in practice since this tool won't be
|
||||
# exposed if no assistant is configured.
|
||||
return {"success": False, "error": "No assistant configured"}
|
||||
|
||||
args = self.parameters(tool_input.tool_args)
|
||||
exposed_entities = async_get_exposed_entities(
|
||||
hass, llm_context.assistant, include_state=True
|
||||
)
|
||||
|
||||
if not exposed_entities["entities"]:
|
||||
return {"success": False, "error": NO_ENTITIES_PROMPT}
|
||||
|
||||
name_filter = args.get("name")
|
||||
area_filter = args.get("area")
|
||||
domain_filter = args.get("domain")
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
domain_filter = [domain_filter]
|
||||
|
||||
if domain_filter is not None:
|
||||
domain_filter = [
|
||||
normalized_domain
|
||||
for domain in domain_filter
|
||||
if (normalized_domain := domain.strip().lower())
|
||||
]
|
||||
|
||||
if name_filter or area_filter or domain_filter:
|
||||
exposed_states = [
|
||||
state
|
||||
for entity_id in exposed_entities["entities"]
|
||||
if (state := hass.states.get(entity_id)) is not None
|
||||
]
|
||||
match_result = intent.async_match_targets(
|
||||
hass,
|
||||
intent.MatchTargetsConstraints(
|
||||
name=name_filter,
|
||||
area_name=area_filter,
|
||||
domains=domain_filter,
|
||||
# This tool only returns context, so multiple entities
|
||||
# sharing a name (e.g. "AC" in two areas) should all be
|
||||
# returned rather than failing as an ambiguous match.
|
||||
allow_duplicate_names=True,
|
||||
),
|
||||
states=exposed_states,
|
||||
)
|
||||
|
||||
if not match_result.is_match:
|
||||
return {
|
||||
"success": False,
|
||||
"error": _live_context_match_error(
|
||||
match_result, name_filter, area_filter, domain_filter
|
||||
),
|
||||
}
|
||||
|
||||
matched_ids = {state.entity_id for state in match_result.states}
|
||||
entities = [
|
||||
info
|
||||
for entity_id, info in exposed_entities["entities"].items()
|
||||
if entity_id in matched_ids
|
||||
]
|
||||
else:
|
||||
entities = list(exposed_entities["entities"].values())
|
||||
|
||||
prompt = [
|
||||
"Live Context: An overview of the areas"
|
||||
" and the devices in this smart home:",
|
||||
yaml_util.dump(entities),
|
||||
]
|
||||
return {
|
||||
"success": True,
|
||||
"result": "\n".join(prompt),
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
"name": "Model Context Protocol Server",
|
||||
"codeowners": ["@allenporter"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["http", "conversation"],
|
||||
"dependencies": ["http", "conversation", "llm"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/mcp_server",
|
||||
"integration_type": "service",
|
||||
"iot_class": "local_push",
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
"""LLM tools for the script integration."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
ATTR_DOMAIN,
|
||||
ATTR_SERVICE,
|
||||
EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_SERVICE_REMOVED,
|
||||
)
|
||||
from homeassistant.core import Event, HomeAssistant, callback, split_entity_id
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
config_validation as cv,
|
||||
entity_registry as er,
|
||||
floor_registry as fr,
|
||||
intent,
|
||||
llm,
|
||||
selector,
|
||||
service,
|
||||
)
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
ACTION_PARAMETERS_CACHE: HassKey[
|
||||
dict[str, dict[str, tuple[str | None, vol.Schema]]]
|
||||
] = HassKey("llm_action_parameters_cache")
|
||||
|
||||
|
||||
async def async_setup_tools(hass: HomeAssistant) -> None:
|
||||
"""Set up the script LLM tools."""
|
||||
llm.async_register_tool_provider(
|
||||
hass, _script_tools, apis={llm.LLM_API_ASSIST: None}
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def _script_tools(hass: HomeAssistant, llm_context: llm.LLMContext) -> llm.LLMTools:
|
||||
"""Return the script tools for the exposed scripts."""
|
||||
if llm_context.assistant is None:
|
||||
return llm.LLMTools(tools=[])
|
||||
|
||||
exposed = llm.async_get_exposed_entities(
|
||||
hass, llm_context.assistant, include_state=False
|
||||
)
|
||||
return llm.LLMTools(
|
||||
tools=[
|
||||
ScriptTool(hass, script_entity_id) for script_entity_id in exposed[DOMAIN]
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_cached_action_parameters(
|
||||
hass: HomeAssistant, domain: str, action: str
|
||||
) -> tuple[str | None, vol.Schema]:
|
||||
"""Get action description and schema."""
|
||||
description = None
|
||||
parameters = vol.Schema({})
|
||||
|
||||
parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE)
|
||||
|
||||
if parameters_cache is None:
|
||||
parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {}
|
||||
|
||||
@callback
|
||||
def clear_cache(event: Event) -> None:
|
||||
"""Clear action parameter cache on action removal."""
|
||||
if (
|
||||
event.data[ATTR_DOMAIN] in parameters_cache
|
||||
and event.data[ATTR_SERVICE]
|
||||
in parameters_cache[event.data[ATTR_DOMAIN]]
|
||||
):
|
||||
parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE])
|
||||
|
||||
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
|
||||
|
||||
@callback
|
||||
def on_homeassistant_close(event: Event) -> None:
|
||||
"""Cleanup."""
|
||||
cancel()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close)
|
||||
|
||||
if domain in parameters_cache and action in parameters_cache[domain]:
|
||||
return parameters_cache[domain][action]
|
||||
|
||||
if action_desc := service.async_get_cached_service_description(
|
||||
hass, domain, action
|
||||
):
|
||||
description = action_desc.get("description")
|
||||
schema: dict[vol.Marker, Any] = {}
|
||||
fields = action_desc.get("fields", {})
|
||||
|
||||
for field, config in fields.items():
|
||||
field_description = config.get("description")
|
||||
if not field_description:
|
||||
field_description = config.get("name")
|
||||
key: vol.Marker
|
||||
if config.get("required"):
|
||||
key = vol.Required(field, description=field_description)
|
||||
else:
|
||||
key = vol.Optional(field, description=field_description)
|
||||
if "selector" in config:
|
||||
schema[key] = selector.selector(config["selector"])
|
||||
else:
|
||||
schema[key] = cv.string
|
||||
|
||||
parameters = vol.Schema(schema)
|
||||
|
||||
if domain == DOMAIN:
|
||||
entity_registry = er.async_get(hass)
|
||||
if (
|
||||
entity_id := entity_registry.async_get_entity_id(domain, domain, action)
|
||||
) is not None and (
|
||||
entity_entry := entity_registry.async_get(entity_id)
|
||||
) is not None:
|
||||
aliases = er.async_get_entity_aliases(hass, entity_entry)
|
||||
if aliases:
|
||||
if description:
|
||||
description = description + ". Aliases: " + str(sorted(aliases))
|
||||
else:
|
||||
description = "Aliases: " + str(sorted(aliases))
|
||||
|
||||
parameters_cache.setdefault(domain, {})[action] = (description, parameters)
|
||||
|
||||
return description, parameters
|
||||
|
||||
|
||||
class ActionTool(llm.Tool):
|
||||
"""LLM Tool representing an action."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
action: str,
|
||||
) -> None:
|
||||
"""Init the class."""
|
||||
self._domain = domain
|
||||
self._action = action
|
||||
self.name = f"{domain}__{action}"
|
||||
# Note: _get_cached_action_parameters only works for services which
|
||||
# add their description directly to the service description cache.
|
||||
# This is not the case for most services, but it is for scripts.
|
||||
# If we want to use `ActionTool` for services other than scripts, we
|
||||
# need to add a coroutine function to fetch the non-cached description
|
||||
# and schema.
|
||||
self.description, self.parameters = _get_cached_action_parameters(
|
||||
hass, domain, action
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: llm.ToolInput,
|
||||
llm_context: llm.LLMContext,
|
||||
) -> JsonObjectType:
|
||||
"""Call the action."""
|
||||
|
||||
for field, validator in self.parameters.schema.items():
|
||||
if field not in tool_input.tool_args:
|
||||
continue
|
||||
if isinstance(validator, selector.AreaSelector):
|
||||
area_reg = ar.async_get(hass)
|
||||
if validator.config.get("multiple"):
|
||||
areas: list[ar.AreaEntry] = []
|
||||
for area in tool_input.tool_args[field]:
|
||||
areas.extend(intent.find_areas(area, area_reg))
|
||||
tool_input.tool_args[field] = list({area.id for area in areas})
|
||||
else:
|
||||
area = tool_input.tool_args[field]
|
||||
area = list(intent.find_areas(area, area_reg))[0].id
|
||||
tool_input.tool_args[field] = area
|
||||
|
||||
elif isinstance(validator, selector.FloorSelector):
|
||||
floor_reg = fr.async_get(hass)
|
||||
if validator.config.get("multiple"):
|
||||
floors: list[fr.FloorEntry] = []
|
||||
for floor in tool_input.tool_args[field]:
|
||||
floors.extend(intent.find_floors(floor, floor_reg))
|
||||
tool_input.tool_args[field] = list(
|
||||
{floor.floor_id for floor in floors}
|
||||
)
|
||||
else:
|
||||
floor = tool_input.tool_args[field]
|
||||
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
|
||||
tool_input.tool_args[field] = floor
|
||||
|
||||
result = await hass.services.async_call(
|
||||
self._domain,
|
||||
self._action,
|
||||
tool_input.tool_args,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
return {"success": True, "result": result}
|
||||
|
||||
|
||||
class ScriptTool(ActionTool):
|
||||
"""LLM Tool representing a Script."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script_entity_id: str,
|
||||
) -> None:
|
||||
"""Init the class."""
|
||||
script_name = split_entity_id(script_entity_id)[1]
|
||||
|
||||
action = script_name
|
||||
entity_registry = er.async_get(hass)
|
||||
entity_entry = entity_registry.async_get(script_entity_id)
|
||||
if entity_entry and entity_entry.unique_id:
|
||||
action = entity_entry.unique_id
|
||||
|
||||
super().__init__(hass, DOMAIN, action)
|
||||
|
||||
self.name = script_name
|
||||
if self.name[0].isdigit():
|
||||
self.name = "_" + self.name
|
||||
@@ -0,0 +1,104 @@
|
||||
"""LLM tools for the todo integration."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import intent, llm
|
||||
from homeassistant.util.json import JsonObjectType
|
||||
|
||||
from .const import DOMAIN, TodoServices
|
||||
|
||||
|
||||
async def async_setup_tools(hass: HomeAssistant) -> None:
|
||||
"""Set up the todo LLM tools."""
|
||||
llm.async_register_tool_provider(hass, _todo_tools, apis={llm.LLM_API_ASSIST: None})
|
||||
|
||||
|
||||
@callback
|
||||
def _todo_tools(hass: HomeAssistant, llm_context: llm.LLMContext) -> llm.LLMTools:
|
||||
"""Return the todo tools for the exposed to-do lists."""
|
||||
if llm_context.assistant is None:
|
||||
return llm.LLMTools(tools=[])
|
||||
|
||||
exposed = llm.async_get_exposed_entities(
|
||||
hass, llm_context.assistant, include_state=False
|
||||
)
|
||||
names = []
|
||||
for info in exposed["entities"].values():
|
||||
if info["domain"] != DOMAIN:
|
||||
continue
|
||||
names.extend(info["names"].split(", "))
|
||||
if not names:
|
||||
return llm.LLMTools(tools=[])
|
||||
|
||||
return llm.LLMTools(tools=[TodoGetItemsTool(names)])
|
||||
|
||||
|
||||
class TodoGetItemsTool(llm.Tool):
|
||||
"""LLM Tool allowing querying a to-do list."""
|
||||
|
||||
name = "todo_get_items"
|
||||
description = (
|
||||
"Query a to-do list to find out what items are on it. "
|
||||
"Use this to answer questions like "
|
||||
"'What's on my task list?' or "
|
||||
"'Read my grocery list'. "
|
||||
"Filters items by status (needs_action, completed, all)."
|
||||
)
|
||||
|
||||
def __init__(self, todo_lists: list[str]) -> None:
|
||||
"""Init the get items tool."""
|
||||
self.parameters = vol.Schema(
|
||||
{
|
||||
vol.Required("todo_list"): vol.In(todo_lists),
|
||||
vol.Optional(
|
||||
"status",
|
||||
description=(
|
||||
"Filter returned items by status,"
|
||||
" by default returns incomplete"
|
||||
" items"
|
||||
),
|
||||
default="needs_action",
|
||||
): vol.In(["needs_action", "completed", "all"]),
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: llm.ToolInput,
|
||||
llm_context: llm.LLMContext,
|
||||
) -> JsonObjectType:
|
||||
"""Query a to-do list."""
|
||||
data = self.parameters(tool_input.tool_args)
|
||||
result = intent.async_match_targets(
|
||||
hass,
|
||||
intent.MatchTargetsConstraints(
|
||||
name=data["todo_list"],
|
||||
domains=[DOMAIN],
|
||||
assistant=llm_context.assistant,
|
||||
),
|
||||
)
|
||||
if not result.is_match:
|
||||
return {"success": False, "error": "To-do list not found"}
|
||||
entity_id = result.states[0].entity_id
|
||||
service_data: dict[str, Any] = {"entity_id": entity_id}
|
||||
if status := data.get("status"):
|
||||
if status == "all":
|
||||
service_data["status"] = ["needs_action", "completed"]
|
||||
else:
|
||||
service_data["status"] = [status]
|
||||
service_result = await hass.services.async_call(
|
||||
DOMAIN,
|
||||
TodoServices.GET_ITEMS,
|
||||
service_data,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
if not service_result:
|
||||
return {"success": False, "error": "To-do list not found"}
|
||||
items = cast(dict, service_result)[entity_id]["items"]
|
||||
return {"success": True, "result": items}
|
||||
+119
-627
@@ -3,35 +3,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field as dc_field
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from functools import cache, partial
|
||||
from operator import attrgetter
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import slugify as unicode_slug
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import UNSUPPORTED, convert
|
||||
|
||||
from homeassistant.components.calendar import (
|
||||
DOMAIN as CALENDAR_DOMAIN,
|
||||
SERVICE_GET_EVENTS,
|
||||
)
|
||||
from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
|
||||
from homeassistant.components.calendar import DOMAIN as CALENDAR_DOMAIN
|
||||
from homeassistant.components.homeassistant import async_should_expose
|
||||
from homeassistant.components.intent import async_device_supports_timers
|
||||
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
|
||||
from homeassistant.components.sensor import async_rounded_state
|
||||
from homeassistant.components.todo import DOMAIN as TODO_DOMAIN, TodoServices
|
||||
from homeassistant.components.weather import INTENT_GET_WEATHER
|
||||
from homeassistant.const import (
|
||||
ATTR_DOMAIN,
|
||||
ATTR_SERVICE,
|
||||
EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_SERVICE_REMOVED,
|
||||
)
|
||||
from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util import dt as dt_util, yaml as yaml_util
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
@@ -46,15 +31,9 @@ from . import (
|
||||
floor_registry as fr,
|
||||
intent,
|
||||
selector,
|
||||
service,
|
||||
)
|
||||
from .singleton import singleton
|
||||
|
||||
ACTION_PARAMETERS_CACHE: HassKey[
|
||||
dict[str, dict[str, tuple[str | None, vol.Schema]]]
|
||||
] = HassKey("llm_action_parameters_cache")
|
||||
|
||||
|
||||
LLM_API_ASSIST = "assist"
|
||||
|
||||
DATE_TIME_PROMPT = (
|
||||
@@ -72,43 +51,6 @@ NO_ENTITIES_PROMPT = (
|
||||
"to their voice assistant in Home Assistant."
|
||||
)
|
||||
|
||||
DEVICE_CONTROL_TOOL_USAGE_PROMPT = (
|
||||
"When controlling Home Assistant always call the intent tools. "
|
||||
"Use HassTurnOn to lock and HassTurnOff to unlock a lock. "
|
||||
"When controlling a device, prefer passing just name and domain. "
|
||||
"When controlling an area, prefer passing just area name and domain."
|
||||
)
|
||||
|
||||
DYNAMIC_CONTEXT_PROMPT = (
|
||||
"You ARE equipped to answer questions about the"
|
||||
" current state of\n"
|
||||
"the home using the `GetLiveContext` tool."
|
||||
" This is a primary function."
|
||||
" Do not state you lack the\n"
|
||||
"functionality if the question requires live data.\n"
|
||||
"If the user asks about device existence/type"
|
||||
' (e.g., "Do I have lights in the bedroom?"):'
|
||||
" Answer\n"
|
||||
"from the static context below.\n"
|
||||
"If the user asks about the CURRENT state, value,"
|
||||
' or mode (e.g., "Is the lock locked?",\n'
|
||||
'"Is the fan on?",'
|
||||
' "What mode is the thermostat in?",'
|
||||
' "What is the temperature outside?"):\n'
|
||||
" 1. Recognize this requires live data.\n"
|
||||
" 2. You MUST call `GetLiveContext`."
|
||||
" This tool will provide the needed real-time"
|
||||
" information (like temperature from the local"
|
||||
" weather, lock status, etc.).\n"
|
||||
" 3. Use the tool's response** to answer the"
|
||||
" user accurately"
|
||||
' (e.g., "The temperature outside is'
|
||||
' [value from tool].").\n'
|
||||
"For general knowledge questions not about the"
|
||||
" home: Answer truthfully from internal"
|
||||
" knowledge.\n"
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_render_no_api_prompt(hass: HomeAssistant) -> str:
|
||||
@@ -228,6 +170,95 @@ class Tool:
|
||||
return f"<{self.__class__.__name__} - {self.name}>"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMTools:
|
||||
"""Tools and an optional prompt fragment contributed by a provider."""
|
||||
|
||||
tools: list[Tool]
|
||||
prompt: str | None = None
|
||||
|
||||
|
||||
type LLMToolProvider = Callable[[HomeAssistant, LLMContext], LLMTools]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _ToolProviderRegistration:
|
||||
"""A registered tool provider and its importance within an API surface."""
|
||||
|
||||
provider: LLMToolProvider
|
||||
importance: int | None
|
||||
|
||||
|
||||
_TOOL_PROVIDERS: HassKey[dict[str, list[_ToolProviderRegistration]]] = HassKey(
|
||||
"llm_tool_providers"
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_tool_provider(
|
||||
hass: HomeAssistant,
|
||||
provider: LLMToolProvider,
|
||||
*,
|
||||
apis: dict[str, int | None],
|
||||
) -> Callable[[], None]:
|
||||
"""Register a provider that contributes tools to one or more LLM APIs.
|
||||
|
||||
The provider is evaluated per request with the ``LLMContext`` and returns
|
||||
the tools (and an optional prompt fragment) to expose. ``apis`` maps each
|
||||
API id the provider contributes to onto an importance; the importance is
|
||||
not used yet, pass ``None``.
|
||||
"""
|
||||
registrations = hass.data.setdefault(_TOOL_PROVIDERS, {})
|
||||
registered: list[tuple[str, _ToolProviderRegistration]] = []
|
||||
for api_id, importance in apis.items():
|
||||
registration = _ToolProviderRegistration(provider, importance)
|
||||
registrations.setdefault(api_id, []).append(registration)
|
||||
registered.append((api_id, registration))
|
||||
|
||||
@callback
|
||||
def unregister() -> None:
|
||||
"""Unregister the tool provider."""
|
||||
for api_id, registration in registered:
|
||||
registrations[api_id].remove(registration)
|
||||
|
||||
return unregister
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_tool(
|
||||
hass: HomeAssistant,
|
||||
tool: Tool,
|
||||
*,
|
||||
apis: dict[str, int | None],
|
||||
) -> Callable[[], None]:
|
||||
"""Register a single static tool with one or more LLM APIs."""
|
||||
|
||||
@callback
|
||||
def _provider(_hass: HomeAssistant, _llm_context: LLMContext) -> LLMTools:
|
||||
return LLMTools(tools=[tool])
|
||||
|
||||
return async_register_tool_provider(hass, _provider, apis=apis)
|
||||
|
||||
|
||||
@callback
|
||||
def _async_get_registered_tools(
|
||||
hass: HomeAssistant, api_id: str, llm_context: LLMContext
|
||||
) -> LLMTools:
|
||||
"""Return the tools and merged prompt from all providers for an API."""
|
||||
registrations = hass.data.get(_TOOL_PROVIDERS)
|
||||
if not registrations or not (records := registrations.get(api_id)):
|
||||
return LLMTools(tools=[])
|
||||
|
||||
tools: list[Tool] = []
|
||||
prompts: list[str] = []
|
||||
for registration in records:
|
||||
result = registration.provider(hass, llm_context)
|
||||
tools.extend(result.tools)
|
||||
if result.prompt:
|
||||
prompts.append(result.prompt)
|
||||
return LLMTools(tools=tools, prompt="\n".join(prompts) if prompts else None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIInstance:
|
||||
"""Instance of an API to be used by an LLM."""
|
||||
@@ -457,19 +488,6 @@ class MergedAPI(API):
|
||||
class AssistAPI(API):
|
||||
"""API exposing Assist API to LLMs."""
|
||||
|
||||
IGNORE_INTENTS = {
|
||||
intent.INTENT_GET_TEMPERATURE,
|
||||
INTENT_GET_WEATHER,
|
||||
INTENT_OPEN_COVER, # deprecated
|
||||
INTENT_CLOSE_COVER, # deprecated
|
||||
intent.INTENT_GET_STATE,
|
||||
intent.INTENT_NEVERMIND,
|
||||
intent.INTENT_TOGGLE,
|
||||
intent.INTENT_GET_CURRENT_DATE,
|
||||
intent.INTENT_GET_CURRENT_TIME,
|
||||
intent.INTENT_RESPOND,
|
||||
}
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Init the class."""
|
||||
super().__init__(
|
||||
@@ -477,9 +495,6 @@ class AssistAPI(API):
|
||||
id=LLM_API_ASSIST,
|
||||
name="Assist",
|
||||
)
|
||||
self.cached_slugify = cache(
|
||||
partial(unicode_slug.slugify, separator="_", lowercase=False)
|
||||
)
|
||||
|
||||
async def async_get_api_instance(self, llm_context: LLMContext) -> APIInstance:
|
||||
"""Return the instance of the API."""
|
||||
@@ -490,11 +505,20 @@ class AssistAPI(API):
|
||||
else:
|
||||
exposed_entities = None
|
||||
|
||||
registered = _async_get_registered_tools(self.hass, LLM_API_ASSIST, llm_context)
|
||||
|
||||
api_prompt = self._async_get_api_prompt(llm_context, exposed_entities)
|
||||
if registered.prompt:
|
||||
api_prompt = f"{api_prompt}\n{registered.prompt}"
|
||||
|
||||
tools = self._async_get_tools(llm_context, exposed_entities)
|
||||
tools.extend(registered.tools)
|
||||
|
||||
return APIInstance(
|
||||
api=self,
|
||||
api_prompt=self._async_get_api_prompt(llm_context, exposed_entities),
|
||||
api_prompt=api_prompt,
|
||||
llm_context=llm_context,
|
||||
tools=self._async_get_tools(llm_context, exposed_entities),
|
||||
tools=tools,
|
||||
custom_serializer=selector_serializer,
|
||||
)
|
||||
|
||||
@@ -507,24 +531,13 @@ class AssistAPI(API):
|
||||
|
||||
# Collect all parts, filtering out any None values
|
||||
prompt_parts = [
|
||||
DEVICE_CONTROL_TOOL_USAGE_PROMPT,
|
||||
DYNAMIC_CONTEXT_PROMPT,
|
||||
*self._async_get_exposed_entities_prompt(exposed_entities),
|
||||
self._async_get_voice_satellite_area_prompt(llm_context),
|
||||
self._async_get_no_timer_prompt(llm_context),
|
||||
]
|
||||
|
||||
# Filter out None and empty strings before joining
|
||||
return "\n".join([part for part in prompt_parts if part])
|
||||
|
||||
@callback
|
||||
def _async_get_no_timer_prompt(self, llm_context: LLMContext) -> str | None:
|
||||
if not llm_context.device_id or not async_device_supports_timers(
|
||||
self.hass, llm_context.device_id
|
||||
):
|
||||
return "This device is not able to start timers."
|
||||
return None
|
||||
|
||||
@callback
|
||||
def _async_get_voice_satellite_area_prompt(self, llm_context: LLMContext) -> str:
|
||||
"""Return the area prompt for the voice satellite."""
|
||||
@@ -579,70 +592,7 @@ class AssistAPI(API):
|
||||
self, llm_context: LLMContext, exposed_entities: dict | None
|
||||
) -> list[Tool]:
|
||||
"""Return a list of LLM tools."""
|
||||
ignore_intents = self.IGNORE_INTENTS
|
||||
if not llm_context.device_id or not async_device_supports_timers(
|
||||
self.hass, llm_context.device_id
|
||||
):
|
||||
ignore_intents = ignore_intents | {
|
||||
intent.INTENT_START_TIMER,
|
||||
intent.INTENT_CANCEL_TIMER,
|
||||
intent.INTENT_INCREASE_TIMER,
|
||||
intent.INTENT_DECREASE_TIMER,
|
||||
intent.INTENT_PAUSE_TIMER,
|
||||
intent.INTENT_UNPAUSE_TIMER,
|
||||
intent.INTENT_TIMER_STATUS,
|
||||
}
|
||||
|
||||
intent_handlers = [
|
||||
intent_handler
|
||||
for intent_handler in intent.async_get(self.hass)
|
||||
if intent_handler.intent_type not in ignore_intents
|
||||
]
|
||||
|
||||
exposed_domains: set[str] | None = None
|
||||
if exposed_entities is not None:
|
||||
exposed_domains = {
|
||||
info["domain"] for info in exposed_entities["entities"].values()
|
||||
}
|
||||
|
||||
intent_handlers = [
|
||||
intent_handler
|
||||
for intent_handler in intent_handlers
|
||||
if intent_handler.platforms is None
|
||||
or intent_handler.platforms & exposed_domains
|
||||
]
|
||||
|
||||
tools: list[Tool] = [
|
||||
IntentTool(self.cached_slugify(intent_handler.intent_type), intent_handler)
|
||||
for intent_handler in intent_handlers
|
||||
]
|
||||
|
||||
tools.append(GetDateTimeTool())
|
||||
|
||||
if exposed_entities:
|
||||
if exposed_entities[CALENDAR_DOMAIN]:
|
||||
names = []
|
||||
for info in exposed_entities[CALENDAR_DOMAIN].values():
|
||||
names.extend(info["names"].split(", "))
|
||||
tools.append(CalendarGetEventsTool(names))
|
||||
|
||||
if exposed_domains is not None and TODO_DOMAIN in exposed_domains:
|
||||
names = []
|
||||
for info in exposed_entities["entities"].values():
|
||||
if info["domain"] != TODO_DOMAIN:
|
||||
continue
|
||||
names.extend(info["names"].split(", "))
|
||||
tools.append(TodoGetItemsTool(names))
|
||||
|
||||
tools.extend(
|
||||
ScriptTool(self.hass, script_entity_id)
|
||||
for script_entity_id in exposed_entities[SCRIPT_DOMAIN]
|
||||
)
|
||||
|
||||
if exposed_domains:
|
||||
tools.append(GetLiveContextTool())
|
||||
|
||||
return tools
|
||||
return []
|
||||
|
||||
|
||||
def _get_exposed_entities(
|
||||
@@ -755,6 +705,21 @@ def _get_exposed_entities(
|
||||
return data
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_exposed_entities(
|
||||
hass: HomeAssistant,
|
||||
assistant: str,
|
||||
*,
|
||||
include_state: bool = False,
|
||||
) -> dict[str, dict[str, dict[str, Any]]]:
|
||||
"""Get exposed entities for a tool provider.
|
||||
|
||||
Splits out calendars and scripts. Tool providers can use this to reproduce
|
||||
the exact entity names that AssistAPI feeds to its built-in tools.
|
||||
"""
|
||||
return _get_exposed_entities(hass, assistant, include_state=include_state)
|
||||
|
||||
|
||||
def selector_serializer(schema: Any) -> Any: # noqa: C901
|
||||
"""Convert selectors into OpenAPI schema."""
|
||||
if not isinstance(schema, selector.Selector):
|
||||
@@ -893,476 +858,3 @@ def selector_serializer(schema: Any) -> Any: # noqa: C901
|
||||
return {"type": "array", "items": {"type": "string"}}
|
||||
|
||||
return {"type": "string"}
|
||||
|
||||
|
||||
def _get_cached_action_parameters(
|
||||
hass: HomeAssistant, domain: str, action: str
|
||||
) -> tuple[str | None, vol.Schema]:
|
||||
"""Get action description and schema."""
|
||||
description = None
|
||||
parameters = vol.Schema({})
|
||||
|
||||
parameters_cache = hass.data.get(ACTION_PARAMETERS_CACHE)
|
||||
|
||||
if parameters_cache is None:
|
||||
parameters_cache = hass.data[ACTION_PARAMETERS_CACHE] = {}
|
||||
|
||||
@callback
|
||||
def clear_cache(event: Event) -> None:
|
||||
"""Clear action parameter cache on action removal."""
|
||||
if (
|
||||
event.data[ATTR_DOMAIN] in parameters_cache
|
||||
and event.data[ATTR_SERVICE]
|
||||
in parameters_cache[event.data[ATTR_DOMAIN]]
|
||||
):
|
||||
parameters_cache[event.data[ATTR_DOMAIN]].pop(event.data[ATTR_SERVICE])
|
||||
|
||||
cancel = hass.bus.async_listen(EVENT_SERVICE_REMOVED, clear_cache)
|
||||
|
||||
@callback
|
||||
def on_homeassistant_close(event: Event) -> None:
|
||||
"""Cleanup."""
|
||||
cancel()
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, on_homeassistant_close)
|
||||
|
||||
if domain in parameters_cache and action in parameters_cache[domain]:
|
||||
return parameters_cache[domain][action]
|
||||
|
||||
if action_desc := service.async_get_cached_service_description(
|
||||
hass, domain, action
|
||||
):
|
||||
description = action_desc.get("description")
|
||||
schema: dict[vol.Marker, Any] = {}
|
||||
fields = action_desc.get("fields", {})
|
||||
|
||||
for field, config in fields.items():
|
||||
field_description = config.get("description")
|
||||
if not field_description:
|
||||
field_description = config.get("name")
|
||||
key: vol.Marker
|
||||
if config.get("required"):
|
||||
key = vol.Required(field, description=field_description)
|
||||
else:
|
||||
key = vol.Optional(field, description=field_description)
|
||||
if "selector" in config:
|
||||
schema[key] = selector.selector(config["selector"])
|
||||
else:
|
||||
schema[key] = cv.string
|
||||
|
||||
parameters = vol.Schema(schema)
|
||||
|
||||
if domain == SCRIPT_DOMAIN:
|
||||
entity_registry = er.async_get(hass)
|
||||
if (
|
||||
entity_id := entity_registry.async_get_entity_id(domain, domain, action)
|
||||
) is not None and (
|
||||
entity_entry := entity_registry.async_get(entity_id)
|
||||
) is not None:
|
||||
aliases = er.async_get_entity_aliases(hass, entity_entry)
|
||||
if aliases:
|
||||
if description:
|
||||
description = description + ". Aliases: " + str(sorted(aliases))
|
||||
else:
|
||||
description = "Aliases: " + str(sorted(aliases))
|
||||
|
||||
parameters_cache.setdefault(domain, {})[action] = (description, parameters)
|
||||
|
||||
return description, parameters
|
||||
|
||||
|
||||
class ActionTool(Tool):
|
||||
"""LLM Tool representing an action."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
action: str,
|
||||
) -> None:
|
||||
"""Init the class."""
|
||||
self._domain = domain
|
||||
self._action = action
|
||||
self.name = f"{domain}__{action}"
|
||||
# Note: _get_cached_action_parameters only works for services which
|
||||
# add their description directly to the service description cache.
|
||||
# This is not the case for most services, but it is for scripts.
|
||||
# If we want to use `ActionTool` for services other than scripts, we
|
||||
# need to add a coroutine function to fetch the non-cached description
|
||||
# and schema.
|
||||
self.description, self.parameters = _get_cached_action_parameters(
|
||||
hass, domain, action
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
) -> JsonObjectType:
|
||||
"""Call the action."""
|
||||
|
||||
for field, validator in self.parameters.schema.items():
|
||||
if field not in tool_input.tool_args:
|
||||
continue
|
||||
if isinstance(validator, selector.AreaSelector):
|
||||
area_reg = ar.async_get(hass)
|
||||
if validator.config.get("multiple"):
|
||||
areas: list[ar.AreaEntry] = []
|
||||
for area in tool_input.tool_args[field]:
|
||||
areas.extend(intent.find_areas(area, area_reg))
|
||||
tool_input.tool_args[field] = list({area.id for area in areas})
|
||||
else:
|
||||
area = tool_input.tool_args[field]
|
||||
area = list(intent.find_areas(area, area_reg))[0].id
|
||||
tool_input.tool_args[field] = area
|
||||
|
||||
elif isinstance(validator, selector.FloorSelector):
|
||||
floor_reg = fr.async_get(hass)
|
||||
if validator.config.get("multiple"):
|
||||
floors: list[fr.FloorEntry] = []
|
||||
for floor in tool_input.tool_args[field]:
|
||||
floors.extend(intent.find_floors(floor, floor_reg))
|
||||
tool_input.tool_args[field] = list(
|
||||
{floor.floor_id for floor in floors}
|
||||
)
|
||||
else:
|
||||
floor = tool_input.tool_args[field]
|
||||
floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
|
||||
tool_input.tool_args[field] = floor
|
||||
|
||||
result = await hass.services.async_call(
|
||||
self._domain,
|
||||
self._action,
|
||||
tool_input.tool_args,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
return {"success": True, "result": result}
|
||||
|
||||
|
||||
class ScriptTool(ActionTool):
|
||||
"""LLM Tool representing a Script."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
script_entity_id: str,
|
||||
) -> None:
|
||||
"""Init the class."""
|
||||
script_name = split_entity_id(script_entity_id)[1]
|
||||
|
||||
action = script_name
|
||||
entity_registry = er.async_get(hass)
|
||||
entity_entry = entity_registry.async_get(script_entity_id)
|
||||
if entity_entry and entity_entry.unique_id:
|
||||
action = entity_entry.unique_id
|
||||
|
||||
super().__init__(hass, SCRIPT_DOMAIN, action)
|
||||
|
||||
self.name = script_name
|
||||
if self.name[0].isdigit():
|
||||
self.name = "_" + self.name
|
||||
|
||||
|
||||
class CalendarGetEventsTool(Tool):
|
||||
"""LLM Tool allowing querying a calendar."""
|
||||
|
||||
name = "calendar_get_events"
|
||||
description = (
|
||||
"Get events from a calendar. "
|
||||
"When asked if something happens, search the whole week. "
|
||||
"Results are RFC 5545 which means 'end' is exclusive."
|
||||
)
|
||||
|
||||
def __init__(self, calendars: list[str]) -> None:
|
||||
"""Init the get events tool."""
|
||||
self.parameters = vol.Schema(
|
||||
{
|
||||
vol.Required("calendar"): vol.In(calendars),
|
||||
vol.Required("range"): vol.In(["today", "week"]),
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
) -> JsonObjectType:
|
||||
"""Query a calendar."""
|
||||
data = self.parameters(tool_input.tool_args)
|
||||
result = intent.async_match_targets(
|
||||
hass,
|
||||
intent.MatchTargetsConstraints(
|
||||
name=data["calendar"],
|
||||
domains=[CALENDAR_DOMAIN],
|
||||
assistant=llm_context.assistant,
|
||||
),
|
||||
)
|
||||
if not result.is_match:
|
||||
return {"success": False, "error": "Calendar not found"}
|
||||
|
||||
entity_id = result.states[0].entity_id
|
||||
if data["range"] == "today":
|
||||
start = dt_util.now()
|
||||
end = dt_util.start_of_local_day() + timedelta(days=1)
|
||||
elif data["range"] == "week":
|
||||
start = dt_util.now()
|
||||
end = dt_util.start_of_local_day() + timedelta(days=7)
|
||||
|
||||
service_data = {
|
||||
"entity_id": entity_id,
|
||||
"start_date_time": start.isoformat(),
|
||||
"end_date_time": end.isoformat(),
|
||||
}
|
||||
|
||||
service_result = await hass.services.async_call(
|
||||
CALENDAR_DOMAIN,
|
||||
SERVICE_GET_EVENTS,
|
||||
service_data,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
events = [
|
||||
event if "T" in event["start"] else {**event, "all_day": True}
|
||||
for event in cast(dict, service_result)[entity_id]["events"]
|
||||
]
|
||||
|
||||
return {"success": True, "result": events}
|
||||
|
||||
|
||||
class TodoGetItemsTool(Tool):
|
||||
"""LLM Tool allowing querying a to-do list."""
|
||||
|
||||
name = "todo_get_items"
|
||||
description = (
|
||||
"Query a to-do list to find out what items are on it. "
|
||||
"Use this to answer questions like "
|
||||
"'What's on my task list?' or "
|
||||
"'Read my grocery list'. "
|
||||
"Filters items by status (needs_action, completed, all)."
|
||||
)
|
||||
|
||||
def __init__(self, todo_lists: list[str]) -> None:
|
||||
"""Init the get items tool."""
|
||||
self.parameters = vol.Schema(
|
||||
{
|
||||
vol.Required("todo_list"): vol.In(todo_lists),
|
||||
vol.Optional(
|
||||
"status",
|
||||
description=(
|
||||
"Filter returned items by status,"
|
||||
" by default returns incomplete"
|
||||
" items"
|
||||
),
|
||||
default="needs_action",
|
||||
): vol.In(["needs_action", "completed", "all"]),
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self, hass: HomeAssistant, tool_input: ToolInput, llm_context: LLMContext
|
||||
) -> JsonObjectType:
|
||||
"""Query a to-do list."""
|
||||
data = self.parameters(tool_input.tool_args)
|
||||
result = intent.async_match_targets(
|
||||
hass,
|
||||
intent.MatchTargetsConstraints(
|
||||
name=data["todo_list"],
|
||||
domains=[TODO_DOMAIN],
|
||||
assistant=llm_context.assistant,
|
||||
),
|
||||
)
|
||||
if not result.is_match:
|
||||
return {"success": False, "error": "To-do list not found"}
|
||||
entity_id = result.states[0].entity_id
|
||||
service_data: dict[str, Any] = {"entity_id": entity_id}
|
||||
if status := data.get("status"):
|
||||
if status == "all":
|
||||
service_data["status"] = ["needs_action", "completed"]
|
||||
else:
|
||||
service_data["status"] = [status]
|
||||
service_result = await hass.services.async_call(
|
||||
TODO_DOMAIN,
|
||||
TodoServices.GET_ITEMS,
|
||||
service_data,
|
||||
context=llm_context.context,
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
if not service_result:
|
||||
return {"success": False, "error": "To-do list not found"}
|
||||
items = cast(dict, service_result)[entity_id]["items"]
|
||||
return {"success": True, "result": items}
|
||||
|
||||
|
||||
def _live_context_match_error(
|
||||
match_result: intent.MatchTargetsResult,
|
||||
name_filter: str | None,
|
||||
area_filter: str | None,
|
||||
domain_filter: list[str] | None,
|
||||
) -> str:
|
||||
"""Build an actionable error message for a failed GetLiveContext match."""
|
||||
reason = match_result.no_match_reason
|
||||
if reason is intent.MatchFailedReason.INVALID_AREA:
|
||||
return f"Area '{match_result.no_match_name}' does not exist"
|
||||
if reason is intent.MatchFailedReason.NAME:
|
||||
return f"No exposed entities matched name '{name_filter}'"
|
||||
if reason is intent.MatchFailedReason.AREA:
|
||||
return f"No exposed entities found in area '{area_filter}'"
|
||||
if reason is intent.MatchFailedReason.DOMAIN:
|
||||
domains = ", ".join(domain_filter) if domain_filter else ""
|
||||
return f"No exposed entities found in domain(s): {domains}"
|
||||
return "No entities matched the provided filter"
|
||||
|
||||
|
||||
class GetLiveContextTool(Tool):
|
||||
"""Tool for getting the current state of exposed entities.
|
||||
|
||||
This returns state for all entities that have been exposed to
|
||||
the assistant. This is different than the GetState intent, which
|
||||
returns state for entities based on intent parameters.
|
||||
"""
|
||||
|
||||
name = "GetLiveContext"
|
||||
description = (
|
||||
"Provides real-time information about the"
|
||||
" CURRENT state, value, or mode of devices,"
|
||||
" sensors, entities, or areas. "
|
||||
"Use this tool for: "
|
||||
"1. Answering questions about current"
|
||||
" conditions (e.g., 'Is the light on?'). "
|
||||
"2. As the first step in conditional actions"
|
||||
" (e.g., 'If the weather is rainy, turn off"
|
||||
" sprinklers' requires checking the weather"
|
||||
" first). "
|
||||
"You may filter for devices by name, domain,"
|
||||
" and area, including combining those"
|
||||
" filters. "
|
||||
"Prefer filtering by domain when searching"
|
||||
" for multiple devices of the same type."
|
||||
)
|
||||
parameters = vol.Schema(
|
||||
{
|
||||
vol.Optional(
|
||||
"name",
|
||||
description="Filter entities by name or alias (case-insensitive).",
|
||||
): cv.string,
|
||||
vol.Optional(
|
||||
"domain",
|
||||
description=(
|
||||
"Filter entities by domain"
|
||||
" (e.g. 'light', 'sensor')."
|
||||
" Accepts a single domain or a list."
|
||||
),
|
||||
): vol.Any(cv.string, [cv.string]),
|
||||
vol.Optional(
|
||||
"area",
|
||||
description="Filter entities by area name or alias (case-insensitive).",
|
||||
): cv.string,
|
||||
}
|
||||
)
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: ToolInput,
|
||||
llm_context: LLMContext,
|
||||
) -> JsonObjectType:
|
||||
"""Get the current state of exposed entities."""
|
||||
if llm_context.assistant is None:
|
||||
# Note this doesn't happen in practice since this tool won't be
|
||||
# exposed if no assistant is configured.
|
||||
return {"success": False, "error": "No assistant configured"}
|
||||
|
||||
args = self.parameters(tool_input.tool_args)
|
||||
exposed_entities = _get_exposed_entities(hass, llm_context.assistant)
|
||||
|
||||
if not exposed_entities["entities"]:
|
||||
return {"success": False, "error": NO_ENTITIES_PROMPT}
|
||||
|
||||
name_filter = args.get("name")
|
||||
area_filter = args.get("area")
|
||||
domain_filter = args.get("domain")
|
||||
|
||||
if isinstance(domain_filter, str):
|
||||
domain_filter = [domain_filter]
|
||||
|
||||
if domain_filter is not None:
|
||||
domain_filter = [
|
||||
normalized_domain
|
||||
for domain in domain_filter
|
||||
if (normalized_domain := domain.strip().lower())
|
||||
]
|
||||
|
||||
if name_filter or area_filter or domain_filter:
|
||||
exposed_states = [
|
||||
state
|
||||
for entity_id in exposed_entities["entities"]
|
||||
if (state := hass.states.get(entity_id)) is not None
|
||||
]
|
||||
match_result = intent.async_match_targets(
|
||||
hass,
|
||||
intent.MatchTargetsConstraints(
|
||||
name=name_filter,
|
||||
area_name=area_filter,
|
||||
domains=domain_filter,
|
||||
# This tool only returns context, so multiple entities
|
||||
# sharing a name (e.g. "AC" in two areas) should all be
|
||||
# returned rather than failing as an ambiguous match.
|
||||
allow_duplicate_names=True,
|
||||
),
|
||||
states=exposed_states,
|
||||
)
|
||||
|
||||
if not match_result.is_match:
|
||||
return {
|
||||
"success": False,
|
||||
"error": _live_context_match_error(
|
||||
match_result, name_filter, area_filter, domain_filter
|
||||
),
|
||||
}
|
||||
|
||||
matched_ids = {state.entity_id for state in match_result.states}
|
||||
entities = [
|
||||
info
|
||||
for entity_id, info in exposed_entities["entities"].items()
|
||||
if entity_id in matched_ids
|
||||
]
|
||||
else:
|
||||
entities = list(exposed_entities["entities"].values())
|
||||
|
||||
prompt = [
|
||||
"Live Context: An overview of the areas"
|
||||
" and the devices in this smart home:",
|
||||
yaml_util.dump(entities),
|
||||
]
|
||||
return {
|
||||
"success": True,
|
||||
"result": "\n".join(prompt),
|
||||
}
|
||||
|
||||
|
||||
class GetDateTimeTool(Tool):
|
||||
"""Tool for getting the current date and time."""
|
||||
|
||||
name = "GetDateTime"
|
||||
description = "Provides the current date and time."
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: ToolInput,
|
||||
llm_context: LLMContext,
|
||||
) -> JsonObjectType:
|
||||
"""Get the current date and time."""
|
||||
now = dt_util.now()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": {
|
||||
"date": now.strftime("%Y-%m-%d"),
|
||||
"time": now.strftime("%H:%M:%S"),
|
||||
"timezone": now.strftime("%Z"),
|
||||
"weekday": now.strftime("%A"),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2101,6 +2101,7 @@ NO_QUALITY_SCALE = [
|
||||
"intent_script",
|
||||
"intent",
|
||||
"labs",
|
||||
"llm",
|
||||
"logbook",
|
||||
"logger",
|
||||
"lovelace",
|
||||
|
||||
@@ -843,9 +843,11 @@ async def test_redacted_thinking(
|
||||
assert chat_log.content[1:] == snapshot
|
||||
|
||||
|
||||
@patch("homeassistant.helpers.llm._async_get_registered_tools")
|
||||
@patch("homeassistant.components.anthropic.entity.llm.AssistAPI._async_get_tools")
|
||||
async def test_extended_thinking_tool_call(
|
||||
mock_get_tools,
|
||||
mock_registered_tools,
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
@@ -875,6 +877,9 @@ async def test_extended_thinking_tool_call(
|
||||
mock_tool.async_call.return_value = "Test response"
|
||||
|
||||
mock_get_tools.return_value = [mock_tool]
|
||||
# The Assist API now also sources tools from the llm integration's registry;
|
||||
# empty it so this test's tool set is just the mock tool.
|
||||
mock_registered_tools.return_value = llm.LLMTools(tools=[])
|
||||
|
||||
mock_create_stream.return_value = [
|
||||
(
|
||||
|
||||
@@ -28,6 +28,7 @@ from homeassistant.components.conversation.chat_log import (
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session, llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from tests.common import async_fire_time_changed
|
||||
@@ -168,6 +169,11 @@ async def test_dynamic_time_injection(
|
||||
hass: HomeAssistant, mock_conversation_input: ConversationInput
|
||||
) -> None:
|
||||
"""Test that dynamic time injection works correctly."""
|
||||
# The Assist API provides the GetDateTime tool via the llm integration, which
|
||||
# suppresses the dynamic time prompt (case 3 below).
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
class MyAPI(llm.API):
|
||||
"""Test API."""
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Tests for the LLM integration."""
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Tests for the LLM integration."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_platform
|
||||
|
||||
|
||||
async def test_setup(hass: HomeAssistant) -> None:
|
||||
"""Test the integration sets up."""
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
|
||||
|
||||
async def test_tool_platform_discovery(hass: HomeAssistant) -> None:
|
||||
"""Test that an integration's llm tools platform is set up."""
|
||||
platform = Mock(async_setup_tools=AsyncMock())
|
||||
mock_platform(hass, "test.llm", platform)
|
||||
hass.config.components.add("test")
|
||||
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
platform.async_setup_tools.assert_awaited_once_with(hass)
|
||||
@@ -0,0 +1,566 @@
|
||||
# serializer version: 1
|
||||
# name: test_assist_api_snapshot[prompt]
|
||||
list([
|
||||
' 1. Recognize this requires live data.',
|
||||
' 2. You MUST call `GetLiveContext`. This tool will provide the needed real-time information (like temperature from the local weather, lock status, etc.).',
|
||||
' 3. Use the tool\'s response** to answer the user accurately (e.g., "The temperature outside is [value from tool].").',
|
||||
' domain: light',
|
||||
' domain: todo',
|
||||
'"Is the fan on?", "What mode is the thermostat in?", "What is the temperature outside?"):',
|
||||
'- names: Kitchen',
|
||||
'- names: Shopping',
|
||||
'For general knowledge questions not about the home: Answer truthfully from internal knowledge.',
|
||||
'If the user asks about device existence/type (e.g., "Do I have lights in the bedroom?"): Answer',
|
||||
'If the user asks about the CURRENT state, value, or mode (e.g., "Is the lock locked?",',
|
||||
'Static Context: An overview of the areas and the devices in this smart home:',
|
||||
'When controlling Home Assistant always call the intent tools. Use HassTurnOn to lock and HassTurnOff to unlock a lock. When controlling a device, prefer passing just name and domain. When controlling an area, prefer passing just area name and domain.',
|
||||
'You ARE equipped to answer questions about the current state of',
|
||||
"You are in area Test Area and all generic commands like 'turn on the lights' should target this area.",
|
||||
'from the static context below.',
|
||||
'functionality if the question requires live data.',
|
||||
'the home using the `GetLiveContext` tool. This is a primary function. Do not state you lack the',
|
||||
])
|
||||
# ---
|
||||
# name: test_assist_api_snapshot[tools]
|
||||
dict({
|
||||
'GetDateTime': dict({
|
||||
'description': 'Provides the current date and time.',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'GetLiveContext': dict({
|
||||
'description': "Provides real-time information about the CURRENT state, value, or mode of devices, sensors, entities, or areas. Use this tool for: 1. Answering questions about current conditions (e.g., 'Is the light on?'). 2. As the first step in conditional actions (e.g., 'If the weather is rainy, turn off sprinklers' requires checking the weather first). You may filter for devices by name, domain, and area, including combining those filters. Prefer filtering by domain when searching for multiple devices of the same type.",
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'description': 'Filter entities by area name or alias (case-insensitive).',
|
||||
'type': 'string',
|
||||
}),
|
||||
'domain': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
}),
|
||||
dict({
|
||||
'items': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': 'array',
|
||||
}),
|
||||
]),
|
||||
'description': "Filter entities by domain (e.g. 'light', 'sensor'). Accepts a single domain or a list.",
|
||||
}),
|
||||
'name': dict({
|
||||
'description': 'Filter entities by name or alias (case-insensitive).',
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassCancelAllTimers': dict({
|
||||
'description': 'Cancels all timers',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassCancelTimer': dict({
|
||||
'description': 'Cancels a timer',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'start_hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassDecreaseTimer': dict({
|
||||
'description': 'Removes time from a timer',
|
||||
'parameters': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'required': list([
|
||||
'hours',
|
||||
]),
|
||||
}),
|
||||
dict({
|
||||
'required': list([
|
||||
'minutes',
|
||||
]),
|
||||
}),
|
||||
dict({
|
||||
'required': list([
|
||||
'seconds',
|
||||
]),
|
||||
}),
|
||||
]),
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassIncreaseTimer': dict({
|
||||
'description': 'Adds more time to a timer',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassListAddItem': dict({
|
||||
'description': 'Add item to a todo list',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'item': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'item',
|
||||
'name',
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassListCompleteItem': dict({
|
||||
'description': 'Complete item on a todo list',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'item': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'item',
|
||||
'name',
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassListRemoveItem': dict({
|
||||
'description': 'Remove one or more items from a todo list',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'item': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'item',
|
||||
'name',
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassPauseTimer': dict({
|
||||
'description': 'Pauses a running timer',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'start_hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassStartTimer': dict({
|
||||
'description': 'Starts a new timer',
|
||||
'parameters': dict({
|
||||
'anyOf': list([
|
||||
dict({
|
||||
'required': list([
|
||||
'hours',
|
||||
]),
|
||||
}),
|
||||
dict({
|
||||
'required': list([
|
||||
'minutes',
|
||||
]),
|
||||
}),
|
||||
dict({
|
||||
'required': list([
|
||||
'seconds',
|
||||
]),
|
||||
}),
|
||||
]),
|
||||
'properties': dict({
|
||||
'conversation_command': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassTimerStatus': dict({
|
||||
'description': 'Reports the current status of timers',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'start_hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassTurnOff': dict({
|
||||
'description': "Turns off/closes a device or entity. For locks, this performs an 'unlock' action. Use for requests like 'turn off', 'deactivate', 'disable', or 'unlock'.",
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'device_class': dict({
|
||||
'items': dict({
|
||||
'enum': list([
|
||||
'awning',
|
||||
'blind',
|
||||
'curtain',
|
||||
'damper',
|
||||
'door',
|
||||
'garage',
|
||||
'gas',
|
||||
'gate',
|
||||
'identify',
|
||||
'outlet',
|
||||
'projector',
|
||||
'receiver',
|
||||
'restart',
|
||||
'shade',
|
||||
'shutter',
|
||||
'speaker',
|
||||
'switch',
|
||||
'tv',
|
||||
'update',
|
||||
'water',
|
||||
'window',
|
||||
]),
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': 'array',
|
||||
}),
|
||||
'domain': dict({
|
||||
'items': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': 'array',
|
||||
}),
|
||||
'floor': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassTurnOn': dict({
|
||||
'description': "Turns on/opens/presses a device or entity. For locks, this performs a 'lock' action. Use for requests like 'turn on', 'activate', 'enable', or 'lock'.",
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'device_class': dict({
|
||||
'items': dict({
|
||||
'enum': list([
|
||||
'awning',
|
||||
'blind',
|
||||
'curtain',
|
||||
'damper',
|
||||
'door',
|
||||
'garage',
|
||||
'gas',
|
||||
'gate',
|
||||
'identify',
|
||||
'outlet',
|
||||
'projector',
|
||||
'receiver',
|
||||
'restart',
|
||||
'shade',
|
||||
'shutter',
|
||||
'speaker',
|
||||
'switch',
|
||||
'tv',
|
||||
'update',
|
||||
'water',
|
||||
'window',
|
||||
]),
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': 'array',
|
||||
}),
|
||||
'domain': dict({
|
||||
'items': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'type': 'array',
|
||||
}),
|
||||
'floor': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'HassUnpauseTimer': dict({
|
||||
'description': 'Resumes a paused timer',
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'area': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'name': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
'start_hours': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_minutes': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
'start_seconds': dict({
|
||||
'minimum': 0,
|
||||
'type': 'integer',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'calendar_get_events': dict({
|
||||
'description': "Get events from a calendar. When asked if something happens, search the whole week. Results are RFC 5545 which means 'end' is exclusive.",
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'calendar': dict({
|
||||
'enum': list([
|
||||
'Personal',
|
||||
]),
|
||||
'type': 'string',
|
||||
}),
|
||||
'range': dict({
|
||||
'enum': list([
|
||||
'today',
|
||||
'week',
|
||||
]),
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'calendar',
|
||||
'range',
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'test_script': dict({
|
||||
'description': "This is a test script. Aliases: ['test_script']",
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'beer': dict({
|
||||
'description': 'Number of beers',
|
||||
'type': 'string',
|
||||
}),
|
||||
'wine': dict({
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
'todo_get_items': dict({
|
||||
'description': "Query a to-do list to find out what items are on it. Use this to answer questions like 'What's on my task list?' or 'Read my grocery list'. Filters items by status (needs_action, completed, all).",
|
||||
'parameters': dict({
|
||||
'properties': dict({
|
||||
'status': dict({
|
||||
'default': 'needs_action',
|
||||
'description': 'Filter returned items by status, by default returns incomplete items',
|
||||
'enum': list([
|
||||
'all',
|
||||
'completed',
|
||||
'needs_action',
|
||||
]),
|
||||
'type': 'string',
|
||||
}),
|
||||
'todo_list': dict({
|
||||
'enum': list([
|
||||
'Shopping',
|
||||
]),
|
||||
'type': 'string',
|
||||
}),
|
||||
}),
|
||||
'required': list([
|
||||
'todo_list',
|
||||
]),
|
||||
'type': 'object',
|
||||
}),
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
+260
-25
@@ -2,16 +2,19 @@
|
||||
|
||||
from datetime import timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
import voluptuous as vol
|
||||
from voluptuous_openapi import convert
|
||||
|
||||
from homeassistant.components import calendar, todo
|
||||
from homeassistant.components import calendar, script, todo
|
||||
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
|
||||
from homeassistant.components.intent import async_register_timer_handler
|
||||
from homeassistant.components.script import ScriptConfig
|
||||
from homeassistant.core import Context, HomeAssistant, State, SupportsResponse
|
||||
from homeassistant.core import Context, HomeAssistant, State, SupportsResponse, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import (
|
||||
area_registry as ar,
|
||||
@@ -53,6 +56,24 @@ class MyAPI(llm.API):
|
||||
return llm.APIInstance(self, self.prompt, llm_context, self.tools)
|
||||
|
||||
|
||||
class _StubTool(llm.Tool):
|
||||
"""Minimal tool for registry tests."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialize the stub tool."""
|
||||
self.name = name
|
||||
self.description = f"{name} description"
|
||||
|
||||
async def async_call(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
tool_input: llm.ToolInput,
|
||||
llm_context: llm.LLMContext,
|
||||
) -> JsonObjectType:
|
||||
"""Return an empty result."""
|
||||
return {}
|
||||
|
||||
|
||||
async def test_get_api_no_existing(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
@@ -128,6 +149,72 @@ async def test_multiple_apis(hass: HomeAssistant, llm_context: llm.LLMContext) -
|
||||
assert await llm.async_get_api(hass, "test-2", llm_context)
|
||||
|
||||
|
||||
async def test_register_tool_provider(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test registering and unregistering a tool provider."""
|
||||
tool = _StubTool("my_tool")
|
||||
|
||||
@callback
|
||||
def provider(_hass: HomeAssistant, _llm_context: llm.LLMContext) -> llm.LLMTools:
|
||||
return llm.LLMTools(tools=[tool], prompt="use my_tool wisely")
|
||||
|
||||
unreg = llm.async_register_tool_provider(hass, provider, apis={"assist": None})
|
||||
|
||||
result = llm._async_get_registered_tools(hass, "assist", llm_context)
|
||||
assert result.tools == [tool]
|
||||
assert result.prompt == "use my_tool wisely"
|
||||
|
||||
unreg()
|
||||
result = llm._async_get_registered_tools(hass, "assist", llm_context)
|
||||
assert result.tools == []
|
||||
assert result.prompt is None
|
||||
|
||||
|
||||
async def test_register_tool_static(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test the static single-tool registration convenience."""
|
||||
tool = _StubTool("static_tool")
|
||||
unreg = llm.async_register_tool(hass, tool, apis={"assist": None})
|
||||
|
||||
result = llm._async_get_registered_tools(hass, "assist", llm_context)
|
||||
assert result.tools == [tool]
|
||||
assert result.prompt is None
|
||||
|
||||
unreg()
|
||||
assert llm._async_get_registered_tools(hass, "assist", llm_context).tools == []
|
||||
|
||||
|
||||
async def test_register_tool_provider_multiple_apis(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test a provider in multiple APIs and prompt merging across providers."""
|
||||
tool_a = _StubTool("tool_a")
|
||||
tool_b = _StubTool("tool_b")
|
||||
|
||||
@callback
|
||||
def provider_a(_hass: HomeAssistant, _llm_context: llm.LLMContext) -> llm.LLMTools:
|
||||
return llm.LLMTools(tools=[tool_a], prompt="prompt a")
|
||||
|
||||
@callback
|
||||
def provider_b(_hass: HomeAssistant, _llm_context: llm.LLMContext) -> llm.LLMTools:
|
||||
return llm.LLMTools(tools=[tool_b], prompt="prompt b")
|
||||
|
||||
llm.async_register_tool_provider(
|
||||
hass, provider_a, apis={"assist": None, "other": None}
|
||||
)
|
||||
llm.async_register_tool_provider(hass, provider_b, apis={"assist": None})
|
||||
|
||||
assist = llm._async_get_registered_tools(hass, "assist", llm_context)
|
||||
assert assist.tools == [tool_a, tool_b]
|
||||
assert assist.prompt == "prompt a\nprompt b"
|
||||
|
||||
other = llm._async_get_registered_tools(hass, "other", llm_context)
|
||||
assert other.tools == [tool_a]
|
||||
assert other.prompt == "prompt a"
|
||||
|
||||
|
||||
async def test_call_tool_no_existing(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
@@ -148,6 +235,7 @@ async def test_assist_api(
|
||||
) -> None:
|
||||
"""Test Assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
|
||||
entity_registry.async_get_or_create(
|
||||
"light",
|
||||
@@ -327,12 +415,121 @@ async def test_assist_api(
|
||||
}
|
||||
|
||||
|
||||
def _normalize_schema(value: Any) -> Any:
|
||||
"""Recursively sort scalar lists (e.g. enum values) for a stable snapshot.
|
||||
|
||||
Some tool parameter schemas build enum options from Python sets, so their
|
||||
order varies per process. Order is semantically irrelevant here.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return {key: _normalize_schema(val) for key, val in value.items()}
|
||||
if isinstance(value, list):
|
||||
items = [_normalize_schema(item) for item in value]
|
||||
if all(isinstance(item, (str, int, float, bool)) for item in items):
|
||||
return sorted(items, key=repr)
|
||||
return items
|
||||
return value
|
||||
|
||||
|
||||
async def test_assist_api_snapshot(
|
||||
hass: HomeAssistant,
|
||||
entity_registry: er.EntityRegistry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Golden snapshot of the Assist API prompt + tools.
|
||||
|
||||
Behavior-parity net for the v1 tool-platform refactor: the assembled prompt
|
||||
and the full serialized tool set (name, description, parameters) must stay
|
||||
identical as built-in tools and intents move out of AssistAPI into per-
|
||||
integration platforms.
|
||||
"""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
assert await async_setup_component(hass, "calendar", {})
|
||||
assert await async_setup_component(hass, "todo", {})
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"script",
|
||||
{
|
||||
"script": {
|
||||
"test_script": {
|
||||
"description": "This is a test script",
|
||||
"sequence": [],
|
||||
"fields": {
|
||||
"beer": {"description": "Number of beers"},
|
||||
"wine": {},
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
entry = MockConfigEntry(title=None)
|
||||
entry.add_to_hass(hass)
|
||||
device = device_registry.async_get_or_create(
|
||||
config_entry_id=entry.entry_id,
|
||||
connections={("test", "1234")},
|
||||
suggested_area="Test Area",
|
||||
)
|
||||
|
||||
# Expose one entity per tool-bearing domain so every built-in tool appears.
|
||||
for domain, object_id, name in (
|
||||
("light", "kitchen", "Kitchen"),
|
||||
("calendar", "personal", "Personal"),
|
||||
("todo", "shopping", "Shopping"),
|
||||
):
|
||||
created = entity_registry.async_get_or_create(
|
||||
domain,
|
||||
"test",
|
||||
f"mock-{object_id}",
|
||||
original_name=name,
|
||||
suggested_object_id=object_id,
|
||||
)
|
||||
hass.states.async_set(created.entity_id, "on", {"friendly_name": name})
|
||||
async_expose_entity(hass, "conversation", created.entity_id, True)
|
||||
|
||||
async_expose_entity(hass, "conversation", "script.test_script", True)
|
||||
async_register_timer_handler(hass, device.id, lambda *args: None)
|
||||
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
context=Context(),
|
||||
language="*",
|
||||
assistant="conversation",
|
||||
device_id=device.id,
|
||||
)
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
# Order-independent: the v1 refactor moves tools/prompt fragments out of
|
||||
# AssistAPI into per-integration registrations, which changes their order but
|
||||
# not their content. Compare tools as a name-keyed mapping and the prompt as a
|
||||
# set of lines so each migration step verifies same-tools/same-content/
|
||||
# same-prompt regardless of ordering. (Final ordering vs dev is checked
|
||||
# separately at the end.)
|
||||
assert sorted(
|
||||
line for line in api.api_prompt.splitlines() if line.strip()
|
||||
) == snapshot(name="prompt")
|
||||
assert {
|
||||
tool.name: {
|
||||
"description": tool.description,
|
||||
"parameters": _normalize_schema(
|
||||
convert(tool.parameters, custom_serializer=api.custom_serializer)
|
||||
),
|
||||
}
|
||||
for tool in api.tools
|
||||
} == snapshot(name="tools")
|
||||
|
||||
|
||||
async def test_assist_api_get_timer_tools(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test getting timer tools with Assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
await hass.async_block_till_done()
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
assert "HassStartTimer" not in [tool.name for tool in api.tools]
|
||||
@@ -351,6 +548,7 @@ async def test_assist_api_tools(
|
||||
"""Test getting timer tools with Assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
|
||||
llm_context.device_id = "test_device"
|
||||
|
||||
@@ -381,10 +579,32 @@ async def test_assist_api_tools(
|
||||
]
|
||||
|
||||
|
||||
async def test_assist_api_registered_tools(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test Assist API includes tools and prompt from the tool registry."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
|
||||
tool = _StubTool("registered_tool")
|
||||
|
||||
@callback
|
||||
def provider(_hass: HomeAssistant, _llm_context: llm.LLMContext) -> llm.LLMTools:
|
||||
return llm.LLMTools(tools=[tool], prompt="registered prompt fragment")
|
||||
|
||||
llm.async_register_tool_provider(hass, provider, apis={"assist": None})
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
assert "registered_tool" in [t.name for t in api.tools]
|
||||
assert "registered prompt fragment" in api.api_prompt
|
||||
|
||||
|
||||
async def test_assist_api_description(
|
||||
hass: HomeAssistant, llm_context: llm.LLMContext
|
||||
) -> None:
|
||||
"""Test intent description with Assist API."""
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
|
||||
class MyIntentHandler(intent.IntentHandler):
|
||||
intent_type = "test_intent"
|
||||
@@ -410,6 +630,7 @@ async def test_assist_api_prompt(
|
||||
"""Test prompt for the assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
@@ -418,6 +639,7 @@ async def test_assist_api_prompt(
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
assert api.api_prompt == (
|
||||
@@ -762,11 +984,11 @@ Static Context: An overview of the areas and the devices in this smart home:
|
||||
)
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{dynamic_context_prompt}
|
||||
{stateless_exposed_entities_prompt}
|
||||
f"""{stateless_exposed_entities_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}"""
|
||||
{first_part_prompt}
|
||||
{no_timer_prompt}
|
||||
{dynamic_context_prompt}"""
|
||||
)
|
||||
|
||||
# Verify that the GetLiveContext tool returns the same results
|
||||
@@ -787,11 +1009,11 @@ Static Context: An overview of the areas and the devices in this smart home:
|
||||
)
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{dynamic_context_prompt}
|
||||
{stateless_exposed_entities_prompt}
|
||||
f"""{stateless_exposed_entities_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}"""
|
||||
{first_part_prompt}
|
||||
{no_timer_prompt}
|
||||
{dynamic_context_prompt}"""
|
||||
)
|
||||
|
||||
# Add floor
|
||||
@@ -804,11 +1026,11 @@ Static Context: An overview of the areas and the devices in this smart home:
|
||||
)
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{dynamic_context_prompt}
|
||||
{stateless_exposed_entities_prompt}
|
||||
f"""{stateless_exposed_entities_prompt}
|
||||
{area_prompt}
|
||||
{no_timer_prompt}"""
|
||||
{first_part_prompt}
|
||||
{no_timer_prompt}
|
||||
{dynamic_context_prompt}"""
|
||||
)
|
||||
|
||||
# Register device for timers
|
||||
@@ -817,10 +1039,10 @@ Static Context: An overview of the areas and the devices in this smart home:
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
# The no_timer_prompt is gone
|
||||
assert api.api_prompt == (
|
||||
f"""{first_part_prompt}
|
||||
{dynamic_context_prompt}
|
||||
{stateless_exposed_entities_prompt}
|
||||
{area_prompt}"""
|
||||
f"""{stateless_exposed_entities_prompt}
|
||||
{area_prompt}
|
||||
{first_part_prompt}
|
||||
{dynamic_context_prompt}"""
|
||||
)
|
||||
|
||||
|
||||
@@ -833,6 +1055,7 @@ async def test_get_live_context_tool_filter(
|
||||
"""Test the filter parameters of the GetLiveContext tool."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
@@ -929,6 +1152,7 @@ async def test_get_live_context_tool_filter(
|
||||
hass.states.async_set(office_ac.entity_id, "cool")
|
||||
hass.states.async_set(kitchen_ac.entity_id, "heat")
|
||||
|
||||
await hass.async_block_till_done()
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
# Filter by area and domain (example 1)
|
||||
@@ -1168,6 +1392,7 @@ async def test_script_tool(
|
||||
"""Test ScriptTool for the assist API."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "intent", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
@@ -1223,11 +1448,11 @@ async def test_script_tool(
|
||||
area = area_registry.async_create("Living room")
|
||||
floor = floor_registry.async_create("2")
|
||||
|
||||
assert llm.ACTION_PARAMETERS_CACHE not in hass.data
|
||||
assert script.llm.ACTION_PARAMETERS_CACHE not in hass.data
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
tools = [tool for tool in api.tools if isinstance(tool, script.llm.ScriptTool)]
|
||||
assert len(tools) == 2
|
||||
|
||||
tool = tools[0]
|
||||
@@ -1247,7 +1472,7 @@ async def test_script_tool(
|
||||
}
|
||||
assert tool.parameters.schema == schema
|
||||
|
||||
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
|
||||
assert hass.data[script.llm.ACTION_PARAMETERS_CACHE]["script"] == {
|
||||
"test_script": (
|
||||
"This is a test script. Aliases: ['script alias', 'script name']",
|
||||
vol.Schema(schema),
|
||||
@@ -1347,11 +1572,11 @@ async def test_script_tool(
|
||||
):
|
||||
await hass.services.async_call("script", "reload", blocking=True)
|
||||
|
||||
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {}
|
||||
assert hass.data[script.llm.ACTION_PARAMETERS_CACHE]["script"] == {}
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
tools = [tool for tool in api.tools if isinstance(tool, script.llm.ScriptTool)]
|
||||
assert len(tools) == 2
|
||||
|
||||
tool = tools[0]
|
||||
@@ -1363,7 +1588,7 @@ async def test_script_tool(
|
||||
schema = {vol.Required("beer", description="Number of beers"): cv.string}
|
||||
assert tool.parameters.schema == schema
|
||||
|
||||
assert hass.data[llm.ACTION_PARAMETERS_CACHE]["script"] == {
|
||||
assert hass.data[script.llm.ACTION_PARAMETERS_CACHE]["script"] == {
|
||||
"test_script": (
|
||||
"This is a new test script. Aliases: ['script alias', 'script name']",
|
||||
vol.Schema(schema),
|
||||
@@ -1378,6 +1603,7 @@ async def test_script_tool(
|
||||
async def test_script_tool_name(hass: HomeAssistant) -> None:
|
||||
"""Test that script tool name is not started with a digit."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
@@ -1407,7 +1633,7 @@ async def test_script_tool_name(hass: HomeAssistant) -> None:
|
||||
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
|
||||
tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
|
||||
tools = [tool for tool in api.tools if isinstance(tool, script.llm.ScriptTool)]
|
||||
assert len(tools) == 1
|
||||
|
||||
tool = tools[0]
|
||||
@@ -1692,6 +1918,8 @@ async def test_selector_serializer(
|
||||
async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||
"""Test the calendar get events tool."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
assert await async_setup_component(hass, "calendar", {})
|
||||
hass.states.async_set(
|
||||
"calendar.test_calendar", "on", {"friendly_name": "Mock Calendar Name"}
|
||||
)
|
||||
@@ -1704,6 +1932,8 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
# Wait for the llm integration to discover the calendar tools platform.
|
||||
await hass.async_block_till_done()
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
tool = next(
|
||||
(tool for tool in api.tools if tool.name == "calendar_get_events"), None
|
||||
@@ -1793,6 +2023,7 @@ async def test_calendar_get_events_tool(hass: HomeAssistant) -> None:
|
||||
async def test_todo_get_items_tool(hass: HomeAssistant) -> None:
|
||||
"""Test the todo get items tool."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
assert await async_setup_component(hass, "todo", {})
|
||||
hass.states.async_set(
|
||||
"todo.test_list", "0", {"friendly_name": "Mock Todo List Name"}
|
||||
@@ -1806,6 +2037,8 @@ async def test_todo_get_items_tool(hass: HomeAssistant) -> None:
|
||||
assistant="conversation",
|
||||
device_id=None,
|
||||
)
|
||||
# Wait for the llm integration to discover the todo tools platform.
|
||||
await hass.async_block_till_done()
|
||||
api = await llm.async_get_api(hass, "assist", llm_context)
|
||||
tool = next((tool for tool in api.tools if tool.name == "todo_get_items"), None)
|
||||
assert tool is not None
|
||||
@@ -1905,6 +2138,7 @@ async def test_get_date_time_tool(hass: HomeAssistant) -> None:
|
||||
"""Test the GetDateTime tool."""
|
||||
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
@@ -1939,6 +2173,7 @@ async def test_get_date_time_tool(hass: HomeAssistant) -> None:
|
||||
async def test_no_tools_exposed(hass: HomeAssistant) -> None:
|
||||
"""Test that tools are not exposed when no entities are exposed."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
assert await async_setup_component(hass, "llm", {})
|
||||
context = Context()
|
||||
llm_context = llm.LLMContext(
|
||||
platform="test_platform",
|
||||
|
||||
Reference in New Issue
Block a user