Add the Model Context Protocol integration (#135058)

* Add the Model Context Protocol integration

* Improvements to mcp integration

* Move the API prompt constant

* Update config flow error handling

* Update test descriptions

* Update tests/components/mcp/test_config_flow.py

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>

* Update tests/components/mcp/test_config_flow.py

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>

* Address PR feedback

* Update homeassistant/components/mcp/coordinator.py

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>

* Move tool parsing to the coordinator

* Update session handling not to use a context manager

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Allen Porter
2025-01-27 11:38:52 -08:00
committed by GitHub
parent 85540cea3f
commit 58b4556a1d
19 changed files with 1011 additions and 0 deletions

View File

@ -316,6 +316,7 @@ homeassistant.components.manual.*
homeassistant.components.mastodon.*
homeassistant.components.matrix.*
homeassistant.components.matter.*
homeassistant.components.mcp.*
homeassistant.components.mcp_server.*
homeassistant.components.mealie.*
homeassistant.components.media_extractor.*

2
CODEOWNERS generated
View File

@ -891,6 +891,8 @@ build.json @home-assistant/supervisor
/tests/components/matrix/ @PaarthShah
/homeassistant/components/matter/ @home-assistant/matter
/tests/components/matter/ @home-assistant/matter
/homeassistant/components/mcp/ @allenporter
/tests/components/mcp/ @allenporter
/homeassistant/components/mcp_server/ @allenporter
/tests/components/mcp_server/ @allenporter
/homeassistant/components/mealie/ @joostlek @andrew-codechimp

View File

@ -0,0 +1,69 @@
"""The Model Context Protocol integration."""
from __future__ import annotations
from dataclasses import dataclass
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from .const import DOMAIN
from .coordinator import ModelContextProtocolCoordinator
from .types import ModelContextProtocolConfigEntry
__all__ = [
"DOMAIN",
"async_setup_entry",
"async_unload_entry",
]
API_PROMPT = "The following tools are available from a remote server named {name}."
async def async_setup_entry(
hass: HomeAssistant, entry: ModelContextProtocolConfigEntry
) -> bool:
"""Set up Model Context Protocol from a config entry."""
coordinator = ModelContextProtocolCoordinator(hass, entry)
await coordinator.async_config_entry_first_refresh()
unsub = llm.async_register_api(
hass,
ModelContextProtocolAPI(
hass=hass,
id=f"{DOMAIN}-{entry.entry_id}",
name=entry.title,
coordinator=coordinator,
),
)
entry.async_on_unload(unsub)
entry.runtime_data = coordinator
entry.async_on_unload(coordinator.close)
return True
async def async_unload_entry(
hass: HomeAssistant, entry: ModelContextProtocolConfigEntry
) -> bool:
"""Unload a config entry."""
return True
@dataclass(kw_only=True)
class ModelContextProtocolAPI(llm.API):
"""Define an object to hold the Model Context Protocol API."""
coordinator: ModelContextProtocolCoordinator
async def async_get_api_instance(
self, llm_context: llm.LLMContext
) -> llm.APIInstance:
"""Return the instance of the API."""
return llm.APIInstance(
self,
API_PROMPT.format(name=self.name),
llm_context,
tools=self.coordinator.data,
)

View File

@ -0,0 +1,111 @@
"""Config flow for the Model Context Protocol integration."""
from __future__ import annotations
import logging
from typing import Any
import httpx
import voluptuous as vol
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from .const import DOMAIN
from .coordinator import mcp_client
_LOGGER = logging.getLogger(__name__)
STEP_USER_DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_URL): str,
}
)
async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str, Any]:
"""Validate the user input and connect to the MCP server."""
url = data[CONF_URL]
try:
cv.url(url) # Cannot be added to schema directly
except vol.Invalid as error:
raise InvalidUrl from error
try:
async with mcp_client(url) as session:
response = await session.initialize()
except httpx.TimeoutException as error:
_LOGGER.info("Timeout connecting to MCP server: %s", error)
raise TimeoutConnectError from error
except httpx.HTTPStatusError as error:
_LOGGER.info("Cannot connect to MCP server: %s", error)
if error.response.status_code == 401:
raise InvalidAuth from error
raise CannotConnect from error
except httpx.HTTPError as error:
_LOGGER.info("Cannot connect to MCP server: %s", error)
raise CannotConnect from error
if not response.capabilities.tools:
raise MissingCapabilities(
f"MCP Server {url} does not support 'Tools' capability"
)
return {"title": response.serverInfo.name}
class ModelContextProtocolConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Model Context Protocol."""
VERSION = 1
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the initial step."""
errors: dict[str, str] = {}
if user_input is not None:
try:
info = await validate_input(self.hass, user_input)
except InvalidUrl:
errors[CONF_URL] = "invalid_url"
except TimeoutConnectError:
errors["base"] = "timeout_connect"
except CannotConnect:
errors["base"] = "cannot_connect"
except InvalidAuth:
return self.async_abort(reason="invalid_auth")
except MissingCapabilities:
return self.async_abort(reason="missing_capabilities")
except Exception:
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
self._async_abort_entries_match({CONF_URL: user_input[CONF_URL]})
return self.async_create_entry(title=info["title"], data=user_input)
return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
)
class InvalidUrl(HomeAssistantError):
"""Error to indicate the URL format is invalid."""
class CannotConnect(HomeAssistantError):
"""Error to indicate we cannot connect."""
class TimeoutConnectError(HomeAssistantError):
"""Error to indicate we cannot connect."""
class InvalidAuth(HomeAssistantError):
"""Error to indicate there is invalid auth."""
class MissingCapabilities(HomeAssistantError):
"""Error to indicate that the MCP server is missing required capabilities."""

View File

@ -0,0 +1,3 @@
"""Constants for the Model Context Protocol integration."""
DOMAIN = "mcp"

View File

@ -0,0 +1,171 @@
"""Types for the Model Context Protocol integration."""
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import datetime
import logging
import httpx
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
import voluptuous as vol
from voluptuous_openapi import convert_to_voluptuous
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from homeassistant.util.json import JsonObjectType
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
UPDATE_INTERVAL = datetime.timedelta(minutes=30)
TIMEOUT = 10
@asynccontextmanager
async def mcp_client(url: str) -> AsyncGenerator[ClientSession]:
"""Create a server-sent event MCP client.
This is an asynccontext manager that exists to wrap other async context managers
so that the coordinator has a single object to manage.
"""
try:
async with sse_client(url=url) as streams, ClientSession(*streams) as session:
await session.initialize()
yield session
except ExceptionGroup as err:
raise err.exceptions[0] from err
class ModelContextProtocolTool(llm.Tool):
"""A Tool exposed over the Model Context Protocol."""
def __init__(
self,
name: str,
description: str | None,
parameters: vol.Schema,
session: ClientSession,
) -> None:
"""Initialize the tool."""
self.name = name
self.description = description
self.parameters = parameters
self.session = session
async def async_call(
self,
hass: HomeAssistant,
tool_input: llm.ToolInput,
llm_context: llm.LLMContext,
) -> JsonObjectType:
"""Call the tool."""
try:
result = await self.session.call_tool(
tool_input.tool_name, tool_input.tool_args
)
except httpx.HTTPStatusError as error:
raise HomeAssistantError(f"Error when calling tool: {error}") from error
return result.model_dump(exclude_unset=True, exclude_none=True)
class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
"""Define an object to hold MCP data."""
config_entry: ConfigEntry
_session: ClientSession | None = None
_setup_error: Exception | None = None
def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
"""Initialize ModelContextProtocolCoordinator."""
super().__init__(
hass,
logger=_LOGGER,
name=DOMAIN,
config_entry=config_entry,
update_interval=UPDATE_INTERVAL,
)
self._stop = asyncio.Event()
async def _async_setup(self) -> None:
"""Set up the client connection."""
connected = asyncio.Event()
stop = asyncio.Event()
self.config_entry.async_create_background_task(
self.hass, self._connect(connected, stop), "mcp-client"
)
try:
async with asyncio.timeout(TIMEOUT):
await connected.wait()
self._stop = stop
finally:
if self._setup_error is not None:
raise self._setup_error
async def _connect(self, connected: asyncio.Event, stop: asyncio.Event) -> None:
"""Create a server-sent event MCP client."""
url = self.config_entry.data[CONF_URL]
try:
async with (
sse_client(url=url) as streams,
ClientSession(*streams) as session,
):
await session.initialize()
self._session = session
connected.set()
await stop.wait()
except httpx.HTTPStatusError as err:
self._setup_error = err
_LOGGER.debug("Error connecting to MCP server: %s", err)
raise UpdateFailed(f"Error connecting to MCP server: {err}") from err
except ExceptionGroup as err:
self._setup_error = err.exceptions[0]
_LOGGER.debug("Error connecting to MCP server: %s", err)
raise UpdateFailed(
"Error connecting to MCP server: {err.exceptions[0]}"
) from err.exceptions[0]
finally:
self._session = None
async def close(self) -> None:
"""Close the client connection."""
if self._stop is not None:
self._stop.set()
async def _async_update_data(self) -> list[llm.Tool]:
"""Fetch data from API endpoint.
This is the place to pre-process the data to lookup tables
so entities can quickly look up their data.
"""
if self._session is None:
raise UpdateFailed("No session available")
try:
result = await self._session.list_tools()
except httpx.HTTPError as err:
raise UpdateFailed(f"Error communicating with API: {err}") from err
_LOGGER.debug("Received tools: %s", result.tools)
tools: list[llm.Tool] = []
for tool in result.tools:
try:
parameters = convert_to_voluptuous(tool.inputSchema)
except Exception as err:
raise UpdateFailed(
f"Error converting schema {err}: {tool.inputSchema}"
) from err
tools.append(
ModelContextProtocolTool(
tool.name,
tool.description,
parameters,
self._session,
)
)
return tools

View File

@ -0,0 +1,10 @@
{
"domain": "mcp",
"name": "Model Context Protocol",
"codeowners": ["@allenporter"],
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/mcp",
"iot_class": "local_polling",
"quality_scale": "silver",
"requirements": ["mcp==1.1.2"]
}

View File

@ -0,0 +1,88 @@
rules:
# Bronze
action-setup:
status: exempt
comment: Integration does not have actions.
appropriate-polling: done
brands: done
common-modules: done
config-flow-test-coverage: done
config-flow: done
dependency-transparency: done
docs-actions:
status: exempt
comment: Integration does not have actions.
docs-high-level-description: done
docs-installation-instructions: done
docs-removal-instructions: done
entity-event-setup:
status: exempt
comment: Integration does not have entities.
entity-unique-id:
status: exempt
comment: Integration does not have entities.
has-entity-name:
status: exempt
comment: Integration does not have entities.
runtime-data: done
test-before-configure: done
test-before-setup: done
unique-config-entry: done
# Silver
action-exceptions:
status: exempt
comment: Integration does not have actions.
config-entry-unloading: done
docs-configuration-parameters: done
docs-installation-parameters: done
entity-unavailable:
status: exempt
comment: Integration does not have entities.
integration-owner: done
log-when-unavailable: done
parallel-updates:
status: exempt
comment: Integration does not have platforms.
reauthentication-flow:
status: exempt
comment: Integration does not support authentication.
test-coverage: done
# Gold
devices:
status: exempt
comment: Integration does not have devices.
diagnostics: todo
discovery-update-info: todo
discovery: todo
docs-data-update: done
docs-examples: done
docs-known-limitations: done
docs-supported-devices: done
docs-supported-functions: done
docs-troubleshooting: done
docs-use-cases: done
dynamic-devices: todo
entity-category:
status: exempt
comment: Integration does not have entities.
entity-device-class:
status: exempt
comment: Integration does not have entities.
entity-disabled-by-default:
status: exempt
comment: Integration does not have entities.
entity-translations:
status: exempt
comment: Integration does not have entities.
exception-translations: todo
icon-translations: todo
reconfiguration-flow: todo
repair-issues: todo
stale-devices: todo
# Platinum
async-dependency: done
inject-websession: todo
strict-typing: done

View File

@ -0,0 +1,25 @@
{
"config": {
"step": {
"user": {
"data": {
"url": "[%key:common::config_flow::data::url%]"
},
"data_description": {
"url": "The remote MCP server URL for the SSE endpoint, for example http://example/sse"
}
}
},
"error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"unknown": "[%key:common::config_flow::error::unknown%]",
"timeout_connect": "[%key:common::config_flow::error::timeout_connect%]",
"invalid_url": "Must be a valid MCP server URL e.g. https://example.com/sse"
},
"abort": {
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"missing_capabilities": "The MCP server does not support a required capability (Tools)",
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
}
}
}

View File

@ -0,0 +1,7 @@
"""Types for the Model Context Protocol integration."""
from homeassistant.config_entries import ConfigEntry
from .coordinator import ModelContextProtocolCoordinator
type ModelContextProtocolConfigEntry = ConfigEntry[ModelContextProtocolCoordinator]

View File

@ -358,6 +358,7 @@ FLOWS = {
"mailgun",
"mastodon",
"matter",
"mcp",
"mcp_server",
"mealie",
"meater",

View File

@ -3607,6 +3607,12 @@
"config_flow": true,
"iot_class": "local_push"
},
"mcp": {
"name": "Model Context Protocol",
"integration_type": "hub",
"config_flow": true,
"iot_class": "local_polling"
},
"mcp_server": {
"name": "Model Context Protocol Server",
"integration_type": "service",

10
mypy.ini generated
View File

@ -2916,6 +2916,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.mcp.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.mcp_server.*]
check_untyped_defs = true
disallow_incomplete_defs = true

1
requirements_all.txt generated
View File

@ -1364,6 +1364,7 @@ maxcube-api==0.4.3
# homeassistant.components.mythicbeastsdns
mbddns==0.1.2
# homeassistant.components.mcp
# homeassistant.components.mcp_server
mcp==1.1.2

View File

@ -1142,6 +1142,7 @@ maxcube-api==0.4.3
# homeassistant.components.mythicbeastsdns
mbddns==0.1.2
# homeassistant.components.mcp
# homeassistant.components.mcp_server
mcp==1.1.2

View File

@ -0,0 +1 @@
"""Tests for the Model Context Protocol integration."""

View File

@ -0,0 +1,45 @@
"""Common fixtures for the Model Context Protocol tests."""
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components.mcp.const import DOMAIN
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry
TEST_API_NAME = "Memory Server"
@pytest.fixture
def mock_setup_entry() -> Generator[AsyncMock]:
"""Override async_setup_entry."""
with patch(
"homeassistant.components.mcp.async_setup_entry", return_value=True
) as mock_setup_entry:
yield mock_setup_entry
@pytest.fixture
def mock_mcp_client() -> Generator[AsyncMock]:
"""Fixture to mock the MCP client."""
with (
patch("homeassistant.components.mcp.coordinator.sse_client"),
patch("homeassistant.components.mcp.coordinator.ClientSession") as mock_session,
):
yield mock_session.return_value.__aenter__
@pytest.fixture(name="config_entry")
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"""Fixture to load the integration."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_URL: "http://1.1.1.1/sse"},
title=TEST_API_NAME,
)
config_entry.add_to_hass(hass)
return config_entry

View File

@ -0,0 +1,234 @@
"""Test the Model Context Protocol config flow."""
from typing import Any
from unittest.mock import AsyncMock, Mock
import httpx
import pytest
from homeassistant import config_entries
from homeassistant.components.mcp.const import DOMAIN
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from .conftest import TEST_API_NAME
from tests.common import MockConfigEntry
async def test_form(
hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_mcp_client: Mock
) -> None:
"""Test the complete configuration flow."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {}
response = Mock()
response.serverInfo.name = TEST_API_NAME
mock_mcp_client.return_value.initialize.return_value = response
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == TEST_API_NAME
assert result["data"] == {
CONF_URL: "http://1.1.1.1/sse",
}
assert len(mock_setup_entry.mock_calls) == 1
@pytest.mark.parametrize(
("side_effect", "expected_error"),
[
(httpx.TimeoutException("Some timeout"), "timeout_connect"),
(
httpx.HTTPStatusError("", request=None, response=httpx.Response(500)),
"cannot_connect",
),
(httpx.HTTPError("Some HTTP error"), "cannot_connect"),
(Exception, "unknown"),
],
)
async def test_form_mcp_client_error(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
side_effect: Exception,
expected_error: str,
) -> None:
"""Test we handle different client library errors."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
mock_mcp_client.side_effect = side_effect
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
},
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": expected_error}
# Reset the error and make sure the config flow can resume successfully.
mock_mcp_client.side_effect = None
response = Mock()
response.serverInfo.name = TEST_API_NAME
mock_mcp_client.return_value.initialize.return_value = response
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == TEST_API_NAME
assert result["data"] == {
CONF_URL: "http://1.1.1.1/sse",
}
assert len(mock_setup_entry.mock_calls) == 1
@pytest.mark.parametrize(
("side_effect", "expected_error"),
[
(
httpx.HTTPStatusError("", request=None, response=httpx.Response(401)),
"invalid_auth",
),
],
)
async def test_form_mcp_client_error_abort(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
side_effect: Exception,
expected_error: str,
) -> None:
"""Test we handle different client library errors that end with an abort."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
mock_mcp_client.side_effect = side_effect
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == expected_error
@pytest.mark.parametrize(
"user_input",
[
({CONF_URL: "not a url"}),
({CONF_URL: "rtsp://1.1.1.1"}),
],
)
async def test_input_form_validation_error(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
user_input: dict[str, Any],
) -> None:
"""Test we handle invalid auth."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input,
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {CONF_URL: "invalid_url"}
# Reset the error and make sure the config flow can resume successfully.
response = Mock()
response.serverInfo.name = TEST_API_NAME
mock_mcp_client.return_value.initialize.return_value = response
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == TEST_API_NAME
assert result["data"] == {
CONF_URL: "http://1.1.1.1/sse",
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_unique_url(
hass: HomeAssistant, mock_setup_entry: AsyncMock, mock_mcp_client: Mock
) -> None:
"""Test that the same url cannot be configured twice."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_URL: "http://1.1.1.1/sse"},
title=TEST_API_NAME,
)
config_entry.add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {}
response = Mock()
response.serverInfo.name = TEST_API_NAME
mock_mcp_client.return_value.initialize.return_value = response
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
async def test_server_missing_capbilities(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
mock_mcp_client: Mock,
) -> None:
"""Test we handle different client library errors."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
response = Mock()
response.serverInfo.name = TEST_API_NAME
response.capabilities.tools = None
mock_mcp_client.return_value.initialize.return_value = response
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_URL: "http://1.1.1.1/sse",
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "missing_capabilities"

View File

@ -0,0 +1,225 @@
"""Tests for the Model Context Protocol component."""
import re
from unittest.mock import Mock, patch
import httpx
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
import pytest
import voluptuous as vol
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from .conftest import TEST_API_NAME
from tests.common import MockConfigEntry
SEARCH_MEMORY_TOOL = Tool(
name="search_memory",
description="Search memory for relevant context based on a query.",
inputSchema={
"type": "object",
"required": ["query"],
"properties": {
"query": {
"type": "string",
"description": "A free text query to search context for.",
}
},
},
)
SAVE_MEMORY_TOOL = Tool(
name="save_memory",
description="Save a memory context.",
inputSchema={
"type": "object",
"required": ["context"],
"properties": {
"context": {
"type": "object",
"description": "The context to save.",
"properties": {
"fact": {
"type": "string",
"description": "The key for the context.",
},
},
},
},
},
)
def create_llm_context() -> llm.LLMContext:
"""Create a test LLM context."""
return llm.LLMContext(
platform="test_platform",
context=Context(),
user_prompt="test_text",
language="*",
assistant="conversation",
device_id=None,
)
async def test_init(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test the integration is initialized and can be unloaded cleanly."""
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
await hass.config_entries.async_unload(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.NOT_LOADED
async def test_mcp_server_failure(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test the integration fails to setup if the server fails initialization."""
mock_mcp_client.side_effect = httpx.HTTPStatusError(
"", request=None, response=httpx.Response(500)
)
with patch("homeassistant.components.mcp.coordinator.TIMEOUT", 1):
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_list_tools_failure(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test the integration fails to load if the first data fetch returns an error."""
mock_mcp_client.return_value.list_tools.side_effect = httpx.HTTPStatusError(
"", request=None, response=httpx.Response(500)
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_llm_get_api_tools(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test MCP tools are returned as LLM API tools."""
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL, SAVE_MEMORY_TOOL],
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
apis = llm.async_get_apis(hass)
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
assert api
api_instance = await api.async_get_api_instance(create_llm_context())
assert len(api_instance.tools) == 2
tool = api_instance.tools[0]
assert tool.name == "search_memory"
assert tool.description == "Search memory for relevant context based on a query."
with pytest.raises(
vol.Invalid, match=re.escape("required key not provided @ data['query']")
):
tool.parameters({})
assert tool.parameters({"query": "frogs"}) == {"query": "frogs"}
tool = api_instance.tools[1]
assert tool.name == "save_memory"
assert tool.description == "Save a memory context."
with pytest.raises(
vol.Invalid, match=re.escape("required key not provided @ data['context']")
):
tool.parameters({})
assert tool.parameters({"context": {"fact": "User was born in February"}}) == {
"context": {"fact": "User was born in February"}
}
async def test_call_tool(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test calling an MCP Tool through the LLM API."""
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL]
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
apis = llm.async_get_apis(hass)
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
assert api
api_instance = await api.async_get_api_instance(create_llm_context())
assert len(api_instance.tools) == 1
tool = api_instance.tools[0]
assert tool.name == "search_memory"
mock_mcp_client.return_value.call_tool.return_value = CallToolResult(
content=[TextContent(type="text", text="User was born in February")]
)
result = await tool.async_call(
hass,
llm.ToolInput(
tool_name="search_memory", tool_args={"query": "User's birth month"}
),
create_llm_context(),
)
assert result == {
"content": [{"text": "User was born in February", "type": "text"}]
}
async def test_call_tool_fails(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test handling an MCP Tool call failure."""
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL]
)
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.LOADED
apis = llm.async_get_apis(hass)
api = next(iter([api for api in apis if api.name == TEST_API_NAME]))
assert api
api_instance = await api.async_get_api_instance(create_llm_context())
assert len(api_instance.tools) == 1
tool = api_instance.tools[0]
assert tool.name == "search_memory"
mock_mcp_client.return_value.call_tool.side_effect = httpx.HTTPStatusError(
"Server error", request=None, response=httpx.Response(500)
)
with pytest.raises(
HomeAssistantError, match="Error when calling tool: Server error"
):
await tool.async_call(
hass,
llm.ToolInput(
tool_name="search_memory", tool_args={"query": "User's birth month"}
),
create_llm_context(),
)
async def test_convert_tool_schema_fails(
hass: HomeAssistant, config_entry: MockConfigEntry, mock_mcp_client: Mock
) -> None:
"""Test a failure converting an MCP tool schema to a Home Assistant schema."""
mock_mcp_client.return_value.list_tools.return_value = ListToolsResult(
tools=[SEARCH_MEMORY_TOOL]
)
with patch(
"homeassistant.components.mcp.coordinator.convert_to_voluptuous",
side_effect=ValueError,
):
await hass.config_entries.async_setup(config_entry.entry_id)
assert config_entry.state is ConfigEntryState.SETUP_RETRY