Add Ollama conversation agent (#113962)

* Add ollama conversation agent

* Change iot class

* Much better default template

* Slight adjustment to prompt

* Make casing consistent

* Switch to ollama Python fork

* Add prompt to tests

* Rename to "ollama"

* Download models in config flow

* Update homeassistant/components/ollama/config_flow.py

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen
2024-03-26 16:15:20 -05:00
committed by GitHub
parent f94f1fb826
commit 72fed878b4
15 changed files with 1382 additions and 0 deletions

View File

@ -933,6 +933,8 @@ build.json @home-assistant/supervisor
/homeassistant/components/octoprint/ @rfleming71
/tests/components/octoprint/ @rfleming71
/homeassistant/components/ohmconnect/ @robbiet480
/homeassistant/components/ollama/ @synesthesiam
/tests/components/ollama/ @synesthesiam
/homeassistant/components/ombi/ @larssont
/homeassistant/components/omnilogic/ @oliver84 @djtimca @gentoosu
/tests/components/omnilogic/ @oliver84 @djtimca @gentoosu

View File

@ -0,0 +1,266 @@
"""The Ollama integration."""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Literal
import httpx
import ollama
from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_URL, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
from homeassistant.helpers import (
area_registry as ar,
config_validation as cv,
device_registry as dr,
entity_registry as er,
intent,
template,
)
from homeassistant.util import ulid
from .const import (
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_PROMPT,
DEFAULT_MAX_HISTORY,
DEFAULT_PROMPT,
DEFAULT_TIMEOUT,
DOMAIN,
KEEP_ALIVE_FOREVER,
MAX_HISTORY_SECONDS,
)
from .models import ExposedEntity, MessageHistory, MessageRole
_LOGGER = logging.getLogger(__name__)

# Names re-exported from this package. Every entry must be bound in this
# module; a name that is listed here but never defined makes
# ``from ... import *`` fail with AttributeError.
__all__ = [
    "CONF_URL",
    "CONF_PROMPT",
    "CONF_MODEL",
    "CONF_MAX_HISTORY",
    "DOMAIN",
]

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up Ollama from a config entry.

    Builds a client for the configured server URL, checks that the server
    answers a ``list()`` request within ``DEFAULT_TIMEOUT`` seconds, stores the
    client under ``hass.data[DOMAIN][entry_id]`` and registers the
    conversation agent.

    Raises:
        ConfigEntryNotReady: the server timed out or refused the connection.
    """
    merged = {**entry.data, **entry.options}
    server = ollama.AsyncClient(host=merged[CONF_URL])

    # Reachability probe; the response itself is not used.
    try:
        async with asyncio.timeout(DEFAULT_TIMEOUT):
            await server.list()
    except (TimeoutError, httpx.ConnectError) as err:
        raise ConfigEntryNotReady(err) from err

    domain_data = hass.data.setdefault(DOMAIN, {})
    domain_data[entry.entry_id] = server

    agent = OllamaAgent(hass, entry)
    conversation.async_set_agent(hass, entry, agent)
    return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload Ollama.

    Forgets the client stored for this entry and unregisters its
    conversation agent. Always reports success.
    """
    clients = hass.data[DOMAIN]
    del clients[entry.entry_id]
    conversation.async_unset_agent(hass, entry)
    return True
class OllamaAgent(conversation.AbstractConversationAgent):
    """Ollama conversation agent.

    Keeps an in-memory message history per conversation id. Each history
    starts with the rendered system prompt and is followed by user and
    assistant messages. Histories that have not been used for more than
    ``MAX_HISTORY_SECONDS`` are discarded.
    """

    def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
        """Initialize the agent."""
        self.hass = hass
        self.entry = entry

        # conversation id -> message history
        self._history: dict[str, MessageHistory] = {}

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process a sentence.

        Looks up (or creates) the message history for the conversation, sends
        it together with the new user message to the Ollama server and returns
        the assistant's reply as speech. Prompt rendering errors and Ollama
        request/response errors are returned as error responses rather than
        raised.
        """
        settings = {**self.entry.data, **self.entry.options}

        client = self.hass.data[DOMAIN][self.entry.entry_id]
        conversation_id = user_input.conversation_id or ulid.ulid_now()
        model = settings[CONF_MODEL]

        # Look up message history
        message_history = self._history.get(conversation_id)
        if message_history is None:
            # New history
            #
            # Render prompt and error out early if there's a problem
            raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
            try:
                prompt = self._generate_prompt(raw_prompt)
                _LOGGER.debug("Prompt: %s", prompt)
            except TemplateError as err:
                _LOGGER.error("Error rendering prompt: %s", err)
                intent_response = intent.IntentResponse(language=user_input.language)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem generating my prompt: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            message_history = MessageHistory(
                timestamp=time.monotonic(),
                messages=[
                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
                ],
            )
            self._history[conversation_id] = message_history
        else:
            # Bump timestamp so this conversation won't get cleaned up
            message_history.timestamp = time.monotonic()

        # Clean up old histories
        self._prune_old_histories()

        # Trim this message history to keep a maximum number of *user* messages
        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
        self._trim_history(message_history, max_messages)

        # Add new user message
        message_history.messages.append(
            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
        )

        # Get response
        try:
            response = await client.chat(
                model=model,
                # Make a copy of the messages because we mutate the list later
                messages=list(message_history.messages),
                stream=False,
                keep_alive=KEEP_ALIVE_FOREVER,
            )
        except (ollama.RequestError, ollama.ResponseError) as err:
            _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
            intent_response = intent.IntentResponse(language=user_input.language)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Sorry, I had a problem talking to the Ollama server: {err}",
            )
            return conversation.ConversationResult(
                response=intent_response, conversation_id=conversation_id
            )

        response_message = response["message"]
        message_history.messages.append(
            ollama.Message(
                role=response_message["role"], content=response_message["content"]
            )
        )

        # Create intent response
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(response_message["content"])
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _prune_old_histories(self) -> None:
        """Remove message histories not used within MAX_HISTORY_SECONDS."""
        now = time.monotonic()
        self._history = {
            conversation_id: message_history
            for conversation_id, message_history in self._history.items()
            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
        }

    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
        """Trim excess messages from a single history.

        A ``max_messages`` below 1 disables trimming. Otherwise, once the
        history holds at least ``max_messages`` user messages, only the system
        prompt plus the most recent ``2 * max_messages`` messages are kept.
        """
        if max_messages < 1:
            # Keep all messages
            return

        if message_history.num_user_messages < max_messages:
            return

        # Trim history but keep system prompt (first message).
        # Every other message should be an assistant message, so keep 2x
        # message objects.
        num_keep = 2 * max_messages

        # A failed request leaves a user message without an assistant reply, so
        # the history can be shorter than 1 + 2 * (user messages). Clamp the
        # start of the kept tail to index 1: an index of 0 would duplicate the
        # system prompt and a negative one would slice from the end.
        drop_index = max(1, len(message_history.messages) - num_keep)
        message_history.messages = [
            message_history.messages[0],
            *message_history.messages[drop_index:],
        ]

    def _generate_prompt(self, raw_prompt: str) -> str:
        """Render the prompt template.

        The template receives ``ha_name``, ``ha_language`` and
        ``exposed_entities`` as variables.
        """
        return template.Template(raw_prompt, self.hass).async_render(
            {
                "ha_name": self.hass.config.location_name,
                "ha_language": self.hass.config.language,
                "exposed_entities": self._get_exposed_entities(),
            },
            parse_result=False,
        )

    def _get_exposed_entities(self) -> list[ExposedEntity]:
        """Get state list of exposed entities.

        Each result carries the entity's state, its name plus registry
        aliases, and the name and aliases of its area. The entity's own area
        is used when set; otherwise the area of its device.
        """
        area_registry = ar.async_get(self.hass)
        entity_registry = er.async_get(self.hass)
        device_registry = dr.async_get(self.hass)

        exposed_entities = []
        exposed_states = [
            state
            for state in self.hass.states.async_all()
            if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
        ]

        for state in exposed_states:
            entity = entity_registry.async_get(state.entity_id)
            names = [state.name]
            area_names = []

            if entity is not None:
                # Add aliases
                names.extend(entity.aliases)
                if entity.area_id and (
                    area := area_registry.async_get_area(entity.area_id)
                ):
                    # Entity is in area
                    area_names.append(area.name)
                    area_names.extend(area.aliases)
                elif entity.device_id and (
                    device := device_registry.async_get(entity.device_id)
                ):
                    # Check device area
                    if device.area_id and (
                        area := area_registry.async_get_area(device.area_id)
                    ):
                        area_names.append(area.name)
                        area_names.extend(area.aliases)

            exposed_entities.append(
                ExposedEntity(
                    entity_id=state.entity_id,
                    state=state,
                    names=names,
                    area_names=area_names,
                )
            )

        return exposed_entities

View File

@ -0,0 +1,245 @@
"""Config flow for Ollama integration."""
from __future__ import annotations
import asyncio
import logging
import sys
from types import MappingProxyType
from typing import Any
import httpx
import ollama
import voluptuous as vol
from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.const import CONF_URL
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
NumberSelectorMode,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
TemplateSelector,
TextSelector,
TextSelectorConfig,
TextSelectorType,
)
from .const import (
CONF_MAX_HISTORY,
CONF_MODEL,
CONF_PROMPT,
DEFAULT_MAX_HISTORY,
DEFAULT_MODEL,
DEFAULT_PROMPT,
DEFAULT_TIMEOUT,
DOMAIN,
MODEL_NAMES,
)
_LOGGER = logging.getLogger(__name__)

# Schema of the first form of the user step: only the server URL is requested.
STEP_USER_DATA_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_URL): TextSelector(
            TextSelectorConfig(type=TextSelectorType.URL)
        ),
    }
)
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
    """Handle a config flow for Ollama.

    The user step is shown twice: first to collect the server URL, then to
    pick a model. If the chosen model is not yet on the server, a progress
    step waits for the server to pull it before the entry is created.
    """

    VERSION = 1

    def __init__(self) -> None:
        """Initialize config flow."""
        self.url: str | None = None
        self.model: str | None = None
        self.client: ollama.AsyncClient | None = None
        self.download_task: asyncio.Task | None = None

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle the initial step.

        Values from earlier submissions are remembered on the flow, so this
        step handles both the URL form and the model form.
        """
        user_input = user_input or {}
        self.url = user_input.get(CONF_URL, self.url)
        self.model = user_input.get(CONF_MODEL, self.model)

        if self.url is None:
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA, last_step=False
            )

        errors = {}
        downloaded_models: set[str] = set()

        try:
            self.client = ollama.AsyncClient(host=self.url)

            async with asyncio.timeout(DEFAULT_TIMEOUT):
                response = await self.client.list()

            downloaded_models = {
                model_info["model"] for model_info in response.get("models", [])
            }
        except (TimeoutError, httpx.ConnectError):
            errors["base"] = "cannot_connect"
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Unexpected exception")
            errors["base"] = "unknown"

        if errors:
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
            )

        if self.model is None:
            # Show models that have been downloaded first, followed by all known
            # models (only latest tags).
            #
            # Known names are offered with a ":latest" tag, so a known model
            # is skipped when either its bare name or its ":latest" form is
            # already among the downloaded models. Comparing the bare name
            # alone would list a downloaded "<name>:latest" twice.
            models_to_list = [
                SelectOptionDict(label=f"{m} (downloaded)", value=m)
                for m in sorted(downloaded_models)
            ] + [
                SelectOptionDict(label=m, value=f"{m}:latest")
                for m in sorted(MODEL_NAMES)
                if m not in downloaded_models
                and f"{m}:latest" not in downloaded_models
            ]
            model_step_schema = vol.Schema(
                {
                    vol.Required(
                        CONF_MODEL, description={"suggested_value": DEFAULT_MODEL}
                    ): SelectSelector(
                        SelectSelectorConfig(options=models_to_list, custom_value=True)
                    ),
                }
            )

            return self.async_show_form(
                step_id="user",
                data_schema=model_step_schema,
            )

        if self.model not in downloaded_models:
            # Ollama server needs to download model first
            return await self.async_step_download()

        return self.async_create_entry(
            title=_get_title(self.model),
            data={CONF_URL: self.url, CONF_MODEL: self.model},
        )

    async def async_step_download(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step to wait for Ollama server to download a model.

        The pull runs as a background task that is started on the first call.
        Later calls show progress until the task is done, then move on to the
        "finish" or "failed" step.
        """
        assert self.model is not None
        assert self.client is not None

        if self.download_task is None:
            # Tell Ollama server to pull the model.
            # The task will block until the model and metadata are fully
            # downloaded.
            self.download_task = self.hass.async_create_background_task(
                self.client.pull(self.model), f"Downloading {self.model}"
            )

        if self.download_task.done():
            if err := self.download_task.exception():
                # No exception is being handled here, so pass the task's
                # exception explicitly to get its traceback into the log.
                _LOGGER.error(
                    "Unexpected error while downloading model: %s", err, exc_info=err
                )
                return self.async_show_progress_done(next_step_id="failed")

            return self.async_show_progress_done(next_step_id="finish")

        return self.async_show_progress(
            step_id="download",
            progress_action="download",
            progress_task=self.download_task,
        )

    async def async_step_finish(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step after model downloading has succeeded."""
        assert self.url is not None
        assert self.model is not None

        return self.async_create_entry(
            title=_get_title(self.model),
            data={CONF_URL: self.url, CONF_MODEL: self.model},
        )

    async def async_step_failed(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step after model downloading has failed."""
        return self.async_abort(reason="download_failed")

    @staticmethod
    def async_get_options_flow(
        config_entry: ConfigEntry,
    ) -> OptionsFlow:
        """Create the options flow."""
        return OllamaOptionsFlow(config_entry)
class OllamaOptionsFlow(OptionsFlow):
    """Options flow for editing the prompt template and history limit."""

    def __init__(self, config_entry: ConfigEntry) -> None:
        """Remember the config entry and its URL and model."""
        self.config_entry = config_entry
        self.url: str = self.config_entry.data[CONF_URL]
        self.model: str = self.config_entry.data[CONF_MODEL]

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Show the options form, or store the submitted options."""
        if user_input is None:
            # Nothing submitted yet: present the form, using the entry's
            # current options (if any) as suggested values.
            current = self.config_entry.options or MappingProxyType({})
            fields = ollama_config_option_schema(current)
            return self.async_show_form(
                step_id="init",
                data_schema=vol.Schema(fields),
            )

        return self.async_create_entry(title=_get_title(self.model), data=user_input)
def ollama_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
    """Build the field mapping for the Ollama options form.

    Both fields are optional. Their suggested values come from ``options``
    and fall back to the integration defaults.
    """
    suggested_prompt = options.get(CONF_PROMPT, DEFAULT_PROMPT)
    suggested_history = options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)

    prompt_key = vol.Optional(
        CONF_PROMPT,
        description={"suggested_value": suggested_prompt},
    )
    history_key = vol.Optional(
        CONF_MAX_HISTORY,
        description={"suggested_value": suggested_history},
    )
    history_selector = NumberSelector(
        NumberSelectorConfig(
            min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
        )
    )

    return {
        prompt_key: TemplateSelector(),
        history_key: history_selector,
    }
def _get_title(model: str) -> str:
"""Get title for config entry."""
if model.endswith(":latest"):
model = model.split(":", maxsplit=1)[0]
return model

View File

@ -0,0 +1,114 @@
"""Constants for the Ollama integration."""
# Integration domain.
DOMAIN = "ollama"

# Settings keys: name of the Ollama model and the prompt template.
CONF_MODEL = "model"
CONF_PROMPT = "prompt"

# NOTE(review): DEFAULT_PROMPT assigns `attributes_key_printed` inside a
# nested Jinja `for` loop. Jinja normally scopes such assignments to a single
# loop iteration, so "attributes:" may be emitted once per matching attribute
# rather than once per entity. TODO confirm by rendering the template; a
# `namespace()` object would be the usual fix.
DEFAULT_PROMPT = """{%- set used_domains = set([
"binary_sensor",
"climate",
"cover",
"fan",
"light",
"lock",
"sensor",
"switch",
"weather",
]) %}
{%- set used_attributes = set([
"temperature",
"current_temperature",
"temperature_unit",
"brightness",
"humidity",
"unit_of_measurement",
"device_class",
"current_position",
"percentage",
]) %}
This smart home is controlled by Home Assistant.
The current time is {{ now().strftime("%X") }}.
Today's date is {{ now().strftime("%x") }}.
An overview of the areas and the devices in this smart home:
```yaml
{%- for entity in exposed_entities: %}
{%- if entity.domain not in used_domains: %}
{%- continue %}
{%- endif %}
- domain: {{ entity.domain }}
{%- if entity.names | length == 1: %}
name: {{ entity.names[0] }}
{%- else: %}
names:
{%- for name in entity.names: %}
- {{ name }}
{%- endfor %}
{%- endif %}
{%- if entity.area_names | length == 1: %}
area: {{ entity.area_names[0] }}
{%- elif entity.area_names: %}
areas:
{%- for area_name in entity.area_names: %}
- {{ area_name }}
{%- endfor %}
{%- endif %}
state: {{ entity.state.state }}
{%- set attributes_key_printed = False %}
{%- for attr_name, attr_value in entity.state.attributes.items(): %}
{%- if attr_name in used_attributes: %}
{%- if not attributes_key_printed: %}
attributes:
{%- set attributes_key_printed = True %}
{%- endif %}
{{ attr_name }}: {{ attr_value }}
{%- endif %}
{%- endfor %}
{%- endfor %}
```
Answer the user's questions using the information about this smart home.
Keep your answers brief and do not apologize."""
# Passed as `keep_alive` on chat requests.
# NOTE(review): presumably asks the server to keep the model loaded
# indefinitely; confirm against the Ollama API.
KEEP_ALIVE_FOREVER = -1

# Time limit for the `client.list()` reachability check.
DEFAULT_TIMEOUT = 5.0  # seconds

# Settings key and default for the number of user messages after which a
# conversation history is trimmed. Values below 1 disable trimming.
CONF_MAX_HISTORY = "max_history"
DEFAULT_MAX_HISTORY = 20

# Histories unused for longer than this are discarded.
MAX_HISTORY_SECONDS = 60 * 60  # 1 hour

# Known model names offered in the config flow's model picker.
MODEL_NAMES = [  # https://ollama.com/library
    "gemma",
    "llama2",
    "mistral",
    "mixtral",
    "llava",
    "neural-chat",
    "codellama",
    "dolphin-mixtral",
    "qwen",
    "llama2-uncensored",
    "mistral-openorca",
    "deepseek-coder",
    "nous-hermes2",
    "phi",
    "orca-mini",
    "dolphin-mistral",
    "wizard-vicuna-uncensored",
    "vicuna",
    "tinydolphin",
    "llama2-chinese",
    "nomic-embed-text",
    "openhermes",
    "zephyr",
    "tinyllama",
    "openchat",
    "wizardcoder",
    "starcoder",
    "phind-codellama",
    "starcoder2",
]

# Suggested value in the config flow's model picker.
DEFAULT_MODEL = "llama2:latest"

View File

@ -0,0 +1,11 @@
{
"domain": "ollama",
"name": "Ollama",
"codeowners": ["@synesthesiam"],
"config_flow": true,
"dependencies": ["conversation"],
"documentation": "https://www.home-assistant.io/integrations/ollama",
"integration_type": "service",
"iot_class": "local_polling",
"requirements": ["ollama-hass==0.1.7"]
}

View File

@ -0,0 +1,47 @@
"""Models for Ollama integration."""
from dataclasses import dataclass
from enum import StrEnum
from functools import cached_property
import ollama
from homeassistant.core import State
class MessageRole(StrEnum):
    """Role of a chat message.

    Only the roles this integration creates itself are listed. Assistant
    messages reuse the role string from the server's response.
    """

    SYSTEM = "system"  # prompt
    USER = "user"
@dataclass
class MessageHistory:
    """Messages exchanged in one conversation, with a last-use timestamp."""

    # Timestamp of last use in seconds.
    timestamp: float

    # List of message history, including system prompt and assistant responses.
    messages: list[ollama.Message]

    @property
    def num_user_messages(self) -> int:
        """Return how many messages in the history have the user role."""
        user_messages = [
            message
            for message in self.messages
            if message["role"] == MessageRole.USER
        ]
        return len(user_messages)
@dataclass(frozen=True)
class ExposedEntity:
    """State, names and area names of one exposed entity."""

    entity_id: str
    state: State
    names: list[str]
    area_names: list[str]

    @cached_property
    def domain(self) -> str:
        """Return the part of the entity id before the first dot."""
        domain_part, _, _ = self.entity_id.partition(".")
        return domain_part

View File

@ -0,0 +1,33 @@
{
"config": {
"step": {
"user": {
"data": {
"url": "[%key:common::config_flow::data::url%]",
"model": "Model"
}
},
"download": {
"title": "Downloading model"
}
},
"error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"download_failed": "Model downloading failed",
"unknown": "[%key:common::config_flow::error::unknown%]"
},
"progress": {
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
}
},
"options": {
"step": {
"init": {
"data": {
"prompt": "Prompt template",
"max_history": "Max history messages"
}
}
}
}
}

View File

@ -360,6 +360,7 @@ FLOWS = {
"nzbget",
"obihai",
"octoprint",
"ollama",
"omnilogic",
"oncue",
"ondilo_ico",

View File

@ -4136,6 +4136,12 @@
"config_flow": false,
"iot_class": "cloud_polling"
},
"ollama": {
"name": "Ollama",
"integration_type": "service",
"config_flow": true,
"iot_class": "local_polling"
},
"ombi": {
"name": "Ombi",
"integration_type": "hub",

View File

@ -1436,6 +1436,9 @@ odp-amsterdam==6.0.1
# homeassistant.components.oem
oemthermostat==1.1.1
# homeassistant.components.ollama
ollama-hass==0.1.7
# homeassistant.components.omnilogic
omnilogic==0.4.5

View File

@ -1148,6 +1148,9 @@ objgraph==3.5.0
# homeassistant.components.garages_amsterdam
odp-amsterdam==6.0.1
# homeassistant.components.ollama
ollama-hass==0.1.7
# homeassistant.components.omnilogic
omnilogic==0.4.5

View File

@ -0,0 +1,14 @@
"""Tests for the Ollama integration."""
from homeassistant.components import ollama
from homeassistant.components.ollama.const import DEFAULT_PROMPT
# Config entry data used by the test fixtures: server URL and model name.
TEST_USER_DATA = {
    ollama.CONF_URL: "http://localhost:11434",
    ollama.CONF_MODEL: "test model",
}

# Config entry options used by the test fixtures. The small history limit
# lets the trimming test reach the limit after a few messages.
TEST_OPTIONS = {
    ollama.CONF_PROMPT: DEFAULT_PROMPT,
    ollama.CONF_MAX_HISTORY: 2,
}

View File

@ -0,0 +1,37 @@
"""Tests Ollama integration."""
from unittest.mock import patch
import pytest
from homeassistant.components import ollama
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import TEST_OPTIONS, TEST_USER_DATA
from tests.common import MockConfigEntry
@pytest.fixture
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
    """Return an Ollama config entry that has been added to hass."""
    config_entry = MockConfigEntry(
        domain=ollama.DOMAIN, data=TEST_USER_DATA, options=TEST_OPTIONS
    )
    config_entry.add_to_hass(hass)
    return config_entry
@pytest.fixture
async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
    """Set up the homeassistant and ollama components for a test.

    The client's ``list`` call is patched while the Ollama component is set
    up, so the setup-time server check does not contact a real server.
    """
    core_ok = await async_setup_component(hass, "homeassistant", {})
    assert core_ok

    with patch("ollama.AsyncClient.list"):
        ollama_ok = await async_setup_component(hass, ollama.DOMAIN, {})
        assert ollama_ok
        await hass.async_block_till_done()

View File

@ -0,0 +1,234 @@
"""Test the Ollama config flow."""
import asyncio
from unittest.mock import patch
from httpx import ConnectError
import pytest
from homeassistant import config_entries
from homeassistant.components import ollama
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from tests.common import MockConfigEntry
# Model name submitted in the config flow tests.
TEST_MODEL = "test_model:latest"


async def test_form(hass: HomeAssistant) -> None:
    """Test flow when the model is already downloaded.

    The patched ``list`` call reports TEST_MODEL as present on the server, so
    the entry is created right after the model is chosen, without a download
    step.
    """
    # Pretend we already set up a config entry.
    hass.config.components.add(ollama.DOMAIN)
    MockConfigEntry(
        domain=ollama.DOMAIN,
        state=config_entries.ConfigEntryState.LOADED,
    ).add_to_hass(hass)

    # Starting the flow shows the URL form.
    result = await hass.config_entries.flow.async_init(
        ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["type"] == FlowResultType.FORM
    assert result["errors"] is None

    with (
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
            # test model is already "downloaded"
            return_value={"models": [{"model": TEST_MODEL}]},
        ),
        patch(
            "homeassistant.components.ollama.async_setup_entry",
            return_value=True,
        ) as mock_setup_entry,
    ):
        # Step 1: URL
        result2 = await hass.config_entries.flow.async_configure(
            result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
        )
        await hass.async_block_till_done()

        # Step 2: model
        assert result2["type"] == FlowResultType.FORM
        result3 = await hass.config_entries.flow.async_configure(
            result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
        )
        await hass.async_block_till_done()

    # The entry stores exactly the submitted URL and model.
    assert result3["type"] == FlowResultType.CREATE_ENTRY
    assert result3["data"] == {
        ollama.CONF_URL: "http://localhost:11434",
        ollama.CONF_MODEL: TEST_MODEL,
    }
    assert len(mock_setup_entry.mock_calls) == 1
async def test_form_need_download(hass: HomeAssistant) -> None:
    """Test flow when a model needs to be downloaded.

    The patched ``list`` call reports no models, so choosing TEST_MODEL starts
    the download step. A fake ``pull`` blocks until the test releases it,
    which lets the test observe the in-progress state before the flow
    finishes.
    """
    # Pretend we already set up a config entry.
    hass.config.components.add(ollama.DOMAIN)
    MockConfigEntry(
        domain=ollama.DOMAIN,
        state=config_entries.ConfigEntryState.LOADED,
    ).add_to_hass(hass)

    result = await hass.config_entries.flow.async_init(
        ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["type"] == FlowResultType.FORM
    assert result["errors"] is None

    # pull_ready: set by the test to let the fake pull finish.
    # pull_called: set by the fake pull once it has recorded the model.
    pull_ready = asyncio.Event()
    pull_called = asyncio.Event()
    pull_model: str | None = None

    async def pull(self, model: str, *args, **kwargs) -> None:
        # Stand-in for AsyncClient.pull: wait for the test's signal, then
        # record which model was requested.
        nonlocal pull_model

        async with asyncio.timeout(1):
            await pull_ready.wait()

        pull_model = model
        pull_called.set()

    with (
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
            # No models are downloaded
            return_value={},
        ),
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
            pull,
        ),
        patch(
            "homeassistant.components.ollama.async_setup_entry",
            return_value=True,
        ) as mock_setup_entry,
    ):
        # Step 1: URL
        result2 = await hass.config_entries.flow.async_configure(
            result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
        )
        await hass.async_block_till_done()

        # Step 2: model
        assert result2["type"] == FlowResultType.FORM
        result3 = await hass.config_entries.flow.async_configure(
            result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
        )
        await hass.async_block_till_done()

        # Step 3: download
        assert result3["type"] == FlowResultType.SHOW_PROGRESS
        result4 = await hass.config_entries.flow.async_configure(
            result3["flow_id"],
        )
        await hass.async_block_till_done()

        # Run again without the task finishing.
        # We should still be downloading.
        assert result4["type"] == FlowResultType.SHOW_PROGRESS
        result4 = await hass.config_entries.flow.async_configure(
            result4["flow_id"],
        )
        await hass.async_block_till_done()
        assert result4["type"] == FlowResultType.SHOW_PROGRESS

        # Signal fake pull method to complete
        pull_ready.set()
        async with asyncio.timeout(1):
            await pull_called.wait()

        assert pull_model == TEST_MODEL

        # Step 4: finish
        result5 = await hass.config_entries.flow.async_configure(
            result4["flow_id"],
        )

    assert result5["type"] == FlowResultType.CREATE_ENTRY
    assert result5["data"] == {
        ollama.CONF_URL: "http://localhost:11434",
        ollama.CONF_MODEL: TEST_MODEL,
    }
    assert len(mock_setup_entry.mock_calls) == 1
async def test_options(
    hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
    """Test the options form.

    Submitting a prompt and a history limit creates an options entry that
    contains exactly the submitted values.
    """
    options_flow = await hass.config_entries.options.async_init(
        mock_config_entry.entry_id
    )
    options = await hass.config_entries.options.async_configure(
        options_flow["flow_id"],
        {ollama.CONF_PROMPT: "test prompt", ollama.CONF_MAX_HISTORY: 100},
    )
    await hass.async_block_till_done()
    assert options["type"] == FlowResultType.CREATE_ENTRY
    assert options["data"] == {
        ollama.CONF_PROMPT: "test prompt",
        ollama.CONF_MAX_HISTORY: 100,
    }
@pytest.mark.parametrize(
    ("side_effect", "error"),
    [
        (ConnectError(message=""), "cannot_connect"),
        (RuntimeError(), "unknown"),
    ],
)
async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
    """Test we handle errors.

    Each case makes the client's ``list`` call raise ``side_effect`` and
    expects the URL form to be shown again with ``error`` as the base error.
    """
    result = await hass.config_entries.flow.async_init(
        ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    with patch(
        "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
        side_effect=side_effect,
    ):
        result2 = await hass.config_entries.flow.async_configure(
            result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
        )

    assert result2["type"] == FlowResultType.FORM
    assert result2["errors"] == {"base": error}
async def test_download_error(hass: HomeAssistant) -> None:
    """Test we handle errors while downloading a model.

    No models are reported as downloaded and the patched ``pull`` raises, so
    the flow is expected to abort with reason "download_failed".
    """
    result = await hass.config_entries.flow.async_init(
        ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    with (
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
            return_value={},
        ),
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
            side_effect=RuntimeError(),
        ),
    ):
        # URL step
        result2 = await hass.config_entries.flow.async_configure(
            result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
        )
        await hass.async_block_till_done()

        # Model step: the model is not downloaded, so the download step starts
        assert result2["type"] == FlowResultType.FORM
        result3 = await hass.config_entries.flow.async_configure(
            result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
        )
        await hass.async_block_till_done()

        # Re-entering the download step after the pull failed
        assert result3["type"] == FlowResultType.SHOW_PROGRESS
        result4 = await hass.config_entries.flow.async_configure(result3["flow_id"])
        await hass.async_block_till_done()

    assert result4["type"] == FlowResultType.ABORT
    assert result4["reason"] == "download_failed"

View File

@ -0,0 +1,366 @@
"""Tests for the Ollama integration."""
from unittest.mock import AsyncMock, patch
from httpx import ConnectError
from ollama import Message, ResponseError
import pytest
from homeassistant.components import conversation, ollama
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import (
area_registry as ar,
device_registry as dr,
entity_registry as er,
intent,
)
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
async def test_chat(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test that the chat function is called with the appropriate arguments.

    Three lights are created: one whose area comes from its device, one with
    an area of its own, and one that is explicitly not exposed to the
    conversation domain. Only the first two may appear in the rendered
    prompt.
    """
    # Create some areas, devices, and entities
    area_kitchen = area_registry.async_get_or_create("kitchen_id")
    area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
    area_bedroom = area_registry.async_get_or_create("bedroom_id")
    area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
    area_office = area_registry.async_get_or_create("office_id")
    area_office = area_registry.async_update(area_office.id, name="office")

    entry = MockConfigEntry()
    entry.add_to_hass(hass)

    # Kitchen light: the area is assigned to the device, not to the entity.
    kitchen_device = device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections=set(),
        identifiers={("demo", "id-1234")},
    )
    device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)

    kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
    kitchen_light = entity_registry.async_update_entity(
        kitchen_light.entity_id, device_id=kitchen_device.id
    )
    hass.states.async_set(
        kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
    )

    # Bedroom light: the area is assigned directly to the entity.
    bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
    bedroom_light = entity_registry.async_update_entity(
        bedroom_light.entity_id, area_id=area_bedroom.id
    )
    hass.states.async_set(
        bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
    )

    # Hide the office light
    office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
    office_light = entity_registry.async_update_entity(
        office_light.entity_id, area_id=area_office.id
    )
    hass.states.async_set(
        office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
    )
    async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)

    with patch(
        "ollama.AsyncClient.chat",
        return_value={"message": {"role": "assistant", "content": "test response"}},
    ) as mock_chat:
        result = await conversation.async_converse(
            hass,
            "test message",
            None,
            Context(),
            agent_id=mock_config_entry.entry_id,
        )

        assert mock_chat.call_count == 1
        args = mock_chat.call_args.kwargs
        prompt = args["messages"][0]["content"]

        # The request carries the configured model and exactly two messages:
        # the system prompt and the user's text.
        assert args["model"] == "test model"
        assert args["messages"] == [
            Message({"role": "system", "content": prompt}),
            Message({"role": "user", "content": "test message"}),
        ]

        # Verify only exposed devices/areas are in prompt
        assert "kitchen light" in prompt
        assert "bedroom light" in prompt
        assert "office light" not in prompt
        assert "office" not in prompt

        assert (
            result.response.response_type == intent.IntentResponseType.ACTION_DONE
        ), result
        assert result.response.speech["plain"]["speech"] == "test response"
async def test_message_history_trimming(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that a single message history is trimmed according to the config.

    Five user messages are sent on the same conversation id and the
    ``messages`` keyword argument of every mocked chat call is inspected.
    The first three calls carry the full history; the last two must only
    carry the system prompt plus the two most recent exchanges and the new
    user message.
    """
    response_idx = 0

    def response(*args, **kwargs) -> dict:
        """Return a numbered assistant reply so each turn is distinguishable."""
        nonlocal response_idx
        response_idx += 1
        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}

    def turns(call) -> list[tuple[str, str]]:
        """Flatten the messages of one chat call into (role, content) pairs."""
        return [(msg["role"], msg["content"]) for msg in call.kwargs["messages"]]

    with patch(
        "ollama.AsyncClient.chat",
        side_effect=response,
    ) as mock_chat:
        # mock_init_component sets "max_history" to 2
        for i in range(5):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id="1234",
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

    assert mock_chat.call_count == 5
    calls = mock_chat.call_args_list

    # The rendered system prompt is taken from the first call and must be
    # repeated unchanged at the head of every later call.
    prompt = calls[0].kwargs["messages"][0]["content"]

    # system + user-1
    assert len(calls[0].kwargs["messages"]) == 2
    assert calls[0].kwargs["messages"][1]["content"] == "message 1"

    # Full history
    # system + user-1 + assistant-1 + user-2
    assert turns(calls[1]) == [
        ("system", prompt),
        ("user", "message 1"),
        ("assistant", "response 1"),
        ("user", "message 2"),
    ]

    # Full history
    # system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
    assert turns(calls[2]) == [
        ("system", prompt),
        ("user", "message 1"),
        ("assistant", "response 1"),
        ("user", "message 2"),
        ("assistant", "response 2"),
        ("user", "message 3"),
    ]

    # Trimmed down to two user messages.
    # system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
    assert turns(calls[3]) == [
        ("system", prompt),
        ("user", "message 2"),
        ("assistant", "response 2"),
        ("user", "message 3"),
        ("assistant", "response 3"),
        ("user", "message 4"),
    ]

    # Trimmed down to two user messages.
    # system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
    # (The list comparison also pins the length of this fifth call, which
    # the earlier version never checked because it re-tested call index 3.)
    assert turns(calls[4]) == [
        ("system", prompt),
        ("user", "message 3"),
        ("assistant", "response 3"),
        ("user", "message 4"),
        ("assistant", "response 4"),
        ("user", "message 5"),
    ]
async def test_message_history_pruning(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that old message histories are pruned.

    Three conversations are started, two of them are back-dated by two
    hours, and one further converse call is made. Afterwards only the
    untouched third conversation and the newly created one may remain.
    """
    chat_reply = {"message": {"role": "assistant", "content": "test response"}}
    two_hours = 2 * 60 * 60

    with patch("ollama.AsyncClient.chat", return_value=chat_reply):
        # Passing no conversation id makes every call open a fresh history.
        conversation_ids: list[str] = []
        for i in range(3):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id=None,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result
            assert isinstance(result.conversation_id, str)
            conversation_ids.append(result.conversation_id)

        manager = conversation._get_agent_manager(hass)
        agent = await manager.async_get_agent(mock_config_entry.entry_id)
        assert isinstance(agent, ollama.OllamaAgent)
        assert len(agent._history) == 3
        assert agent._history.keys() == set(conversation_ids)

        # Age the first two histories so the next converse call prunes them.
        for stale_id in conversation_ids[:2]:
            stale_history = agent._history[stale_id]
            stale_history.timestamp -= two_hours

        # Next cycle
        result = await conversation.async_converse(
            hass,
            "test message",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )
        assert (
            result.response.response_type == intent.IntentResponseType.ACTION_DONE
        ), result

        # Only the most recent histories should remain
        assert len(agent._history) == 2
        assert conversation_ids[-1] in agent._history
        assert result.conversation_id in agent._history
async def test_message_history_unlimited(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that message history is not trimmed when max_history = 0.

    One hundred messages are sent on one conversation id while the entry
    options are patched to a max history of 0; the single stored history
    must then report all one hundred user messages.
    """
    conversation_id = "1234"
    total_turns = 100
    chat_reply = {"message": {"role": "assistant", "content": "test response"}}

    with (
        patch("ollama.AsyncClient.chat", return_value=chat_reply),
        patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}),
    ):
        for i in range(total_turns):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id=conversation_id,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

        manager = conversation._get_agent_manager(hass)
        agent = await manager.async_get_agent(mock_config_entry.entry_id)
        assert isinstance(agent, ollama.OllamaAgent)

        # Exactly one history exists and nothing was dropped from it.
        assert len(agent._history) == 1
        assert conversation_id in agent._history
        assert agent._history[conversation_id].num_user_messages == total_turns
async def test_error_handling(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test error handling during converse.

    The mocked chat call raises a ResponseError; the conversation result
    must be an error response with the "unknown" error code.
    """
    failing_chat = patch(
        "ollama.AsyncClient.chat",
        new_callable=AsyncMock,
        side_effect=ResponseError("test error"),
    )
    with failing_chat:
        result = await conversation.async_converse(
            hass,
            "hello",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )

    response = result.response
    assert response.response_type == intent.IntentResponseType.ERROR, result
    assert response.error_code == "unknown", result
async def test_template_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test that template error handling works.

    The prompt option is replaced by a template whose ``{% if %}`` block is
    never closed; conversing must then yield an "unknown" error response.
    """
    # Deliberately malformed: there is no matching "{% endif %}".
    broken_prompt = "talk like a {% if True %}smarthome{% else %}pirate please."
    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={"prompt": broken_prompt},
    )

    with patch("ollama.AsyncClient.list"):
        await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        result = await conversation.async_converse(
            hass,
            "hello",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )

    response = result.response
    assert response.response_type == intent.IntentResponseType.ERROR, result
    assert response.error_code == "unknown", result
async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test that the agent for the config entry advertises MATCH_ALL languages."""
    manager = conversation._get_agent_manager(hass)
    agent = await manager.async_get_agent(mock_config_entry.entry_id)
    assert agent.supported_languages == MATCH_ALL
@pytest.mark.parametrize(
    ("side_effect", "error"),
    [
        (ConnectError(message="Connect error"), "Connect error"),
        (RuntimeError("Runtime error"), "Runtime error"),
    ],
)
async def test_init_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, caplog, side_effect, error
) -> None:
    """Test initialization errors.

    With the client's ``list`` call raising ``side_effect``, component setup
    must still return a truthy result and ``error`` must appear in the
    captured log text.
    """
    with patch("ollama.AsyncClient.list", side_effect=side_effect):
        setup_ok = await async_setup_component(hass, ollama.DOMAIN, {})
        assert setup_ok
        await hass.async_block_till_done()
        assert error in caplog.text