Mirror of https://github.com/home-assistant/core.git (synced 2025-06-25 01:21:51 +02:00)
Add Ollama conversation agent (#113962)

* Add ollama conversation agent
* Change iot class
* Much better default template
* Slight adjustment to prompt
* Make casing consistent
* Switch to ollama Python fork
* Add prompt to tests
* Rename to "ollama"
* Download models in config flow
* Update homeassistant/components/ollama/config_flow.py

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
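Once an entry is configured, the agent is driven through the conversation component. A minimal usage sketch; the helper name, user text, and entry-id variable are illustrative, mirroring how the tests in this commit invoke the agent:

from homeassistant.components import conversation
from homeassistant.core import Context, HomeAssistant


async def ask_ollama(hass: HomeAssistant, entry_id: str) -> str:
    """Send one utterance to the Ollama agent and return its spoken reply."""
    result = await conversation.async_converse(
        hass,
        "how warm is it in the kitchen?",  # illustrative user text
        None,  # conversation_id=None starts a new conversation
        Context(),
        agent_id=entry_id,  # entry id of the Ollama config entry
    )
    return result.response.speech["plain"]["speech"]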
CODEOWNERS
@@ -933,6 +933,8 @@ build.json @home-assistant/supervisor
/homeassistant/components/octoprint/ @rfleming71
/tests/components/octoprint/ @rfleming71
/homeassistant/components/ohmconnect/ @robbiet480
/homeassistant/components/ollama/ @synesthesiam
/tests/components/ollama/ @synesthesiam
/homeassistant/components/ombi/ @larssont
/homeassistant/components/omnilogic/ @oliver84 @djtimca @gentoosu
/tests/components/omnilogic/ @oliver84 @djtimca @gentoosu
homeassistant/components/ollama/__init__.py (new file, 266 lines)
@@ -0,0 +1,266 @@
"""The Ollama integration."""

from __future__ import annotations

import asyncio
import logging
import time
from typing import Literal

import httpx
import ollama

from homeassistant.components import conversation
from homeassistant.components.homeassistant.exposed_entities import async_should_expose
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_URL, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, TemplateError
from homeassistant.helpers import (
    area_registry as ar,
    config_validation as cv,
    device_registry as dr,
    entity_registry as er,
    intent,
    template,
)
from homeassistant.util import ulid

from .const import (
    CONF_MAX_HISTORY,
    CONF_MODEL,
    CONF_PROMPT,
    DEFAULT_MAX_HISTORY,
    DEFAULT_PROMPT,
    DEFAULT_TIMEOUT,
    DOMAIN,
    KEEP_ALIVE_FOREVER,
    MAX_HISTORY_SECONDS,
)
from .models import ExposedEntity, MessageHistory, MessageRole

_LOGGER = logging.getLogger(__name__)

__all__ = [
    "CONF_URL",
    "CONF_PROMPT",
    "CONF_MODEL",
    "CONF_MAX_HISTORY",
    "DOMAIN",
]

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)


async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up Ollama from a config entry."""
    settings = {**entry.data, **entry.options}
    client = ollama.AsyncClient(host=settings[CONF_URL])
    try:
        async with asyncio.timeout(DEFAULT_TIMEOUT):
            await client.list()
    except (TimeoutError, httpx.ConnectError) as err:
        raise ConfigEntryNotReady(err) from err

    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client

    conversation.async_set_agent(hass, entry, OllamaAgent(hass, entry))
    return True


async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Unload Ollama."""
    hass.data[DOMAIN].pop(entry.entry_id)
    conversation.async_unset_agent(hass, entry)
    return True


class OllamaAgent(conversation.AbstractConversationAgent):
    """Ollama conversation agent."""

    def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
        """Initialize the agent."""
        self.hass = hass
        self.entry = entry

        # conversation id -> message history
        self._history: dict[str, MessageHistory] = {}

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def async_process(
        self, user_input: conversation.ConversationInput
    ) -> conversation.ConversationResult:
        """Process a sentence."""
        settings = {**self.entry.data, **self.entry.options}

        client = self.hass.data[DOMAIN][self.entry.entry_id]
        conversation_id = user_input.conversation_id or ulid.ulid_now()
        model = settings[CONF_MODEL]

        # Look up message history
        message_history: MessageHistory | None = None
        message_history = self._history.get(conversation_id)
        if message_history is None:
            # New history
            #
            # Render prompt and error out early if there's a problem
            raw_prompt = settings.get(CONF_PROMPT, DEFAULT_PROMPT)
            try:
                prompt = self._generate_prompt(raw_prompt)
                _LOGGER.debug("Prompt: %s", prompt)
            except TemplateError as err:
                _LOGGER.error("Error rendering prompt: %s", err)
                intent_response = intent.IntentResponse(language=user_input.language)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    f"Sorry, I had a problem generating my prompt: {err}",
                )
                return conversation.ConversationResult(
                    response=intent_response, conversation_id=conversation_id
                )

            message_history = MessageHistory(
                timestamp=time.monotonic(),
                messages=[
                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
                ],
            )
            self._history[conversation_id] = message_history
        else:
            # Bump timestamp so this conversation won't get cleaned up
            message_history.timestamp = time.monotonic()

        # Clean up old histories
        self._prune_old_histories()

        # Trim this message history to keep a maximum number of *user* messages
        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
        self._trim_history(message_history, max_messages)

        # Add new user message
        message_history.messages.append(
            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
        )

        # Get response
        try:
            response = await client.chat(
                model=model,
                # Make a copy of the messages because we mutate the list later
                messages=list(message_history.messages),
                stream=False,
                keep_alive=KEEP_ALIVE_FOREVER,
            )
        except (ollama.RequestError, ollama.ResponseError) as err:
            _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
            intent_response = intent.IntentResponse(language=user_input.language)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                f"Sorry, I had a problem talking to the Ollama server: {err}",
            )
            return conversation.ConversationResult(
                response=intent_response, conversation_id=conversation_id
            )

        response_message = response["message"]
        message_history.messages.append(
            ollama.Message(
                role=response_message["role"], content=response_message["content"]
            )
        )

        # Create intent response
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech(response_message["content"])
        return conversation.ConversationResult(
            response=intent_response, conversation_id=conversation_id
        )

    def _prune_old_histories(self) -> None:
        """Remove old message histories."""
        now = time.monotonic()
        self._history = {
            conversation_id: message_history
            for conversation_id, message_history in self._history.items()
            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
        }

    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
        """Trim excess messages from a single history."""
        if max_messages < 1:
            # Keep all messages
            return

        if message_history.num_user_messages >= max_messages:
            # Trim history but keep system prompt (first message).
            # Every other message should be an assistant message, so keep 2x
            # message objects.
            num_keep = 2 * max_messages
            drop_index = len(message_history.messages) - num_keep
            message_history.messages = [
                message_history.messages[0]
            ] + message_history.messages[drop_index:]

    def _generate_prompt(self, raw_prompt: str) -> str:
        """Generate a prompt for the user."""
        return template.Template(raw_prompt, self.hass).async_render(
            {
                "ha_name": self.hass.config.location_name,
                "ha_language": self.hass.config.language,
                "exposed_entities": self._get_exposed_entities(),
            },
            parse_result=False,
        )

    def _get_exposed_entities(self) -> list[ExposedEntity]:
        """Get state list of exposed entities."""
        area_registry = ar.async_get(self.hass)
        entity_registry = er.async_get(self.hass)
        device_registry = dr.async_get(self.hass)

        exposed_entities = []
        exposed_states = [
            state
            for state in self.hass.states.async_all()
            if async_should_expose(self.hass, conversation.DOMAIN, state.entity_id)
        ]

        for state in exposed_states:
            entity = entity_registry.async_get(state.entity_id)
            names = [state.name]
            area_names = []

            if entity is not None:
                # Add aliases
                names.extend(entity.aliases)
                if entity.area_id and (
                    area := area_registry.async_get_area(entity.area_id)
                ):
                    # Entity is in area
                    area_names.append(area.name)
                    area_names.extend(area.aliases)
                elif entity.device_id and (
                    device := device_registry.async_get(entity.device_id)
                ):
                    # Check device area
                    if device.area_id and (
                        area := area_registry.async_get_area(device.area_id)
                    ):
                        area_names.append(area.name)
                        area_names.extend(area.aliases)

            exposed_entities.append(
                ExposedEntity(
                    entity_id=state.entity_id,
                    state=state,
                    names=names,
                    area_names=area_names,
                )
            )

        return exposed_entities
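Outside Home Assistant, the two client calls above (list() as a reachability check at setup, chat() per conversation turn) can be exercised directly against the upstream ollama Python package. A standalone sketch; the server URL, system prompt, and model name are assumptions for illustration:

import asyncio

import ollama


async def main() -> None:
    # Same calls async_setup_entry and async_process make above.
    client = ollama.AsyncClient(host="http://localhost:11434")
    await client.list()  # raises httpx.ConnectError if the server is unreachable
    response = await client.chat(
        model="llama2:latest",
        messages=[
            {"role": "system", "content": "You control a smart home."},
            {"role": "user", "content": "Is the kitchen light on?"},
        ],
        stream=False,
    )
    print(response["message"]["content"])


asyncio.run(main())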
homeassistant/components/ollama/config_flow.py (new file, 245 lines)
@@ -0,0 +1,245 @@
"""Config flow for Ollama integration."""

from __future__ import annotations

import asyncio
import logging
import sys
from types import MappingProxyType
from typing import Any

import httpx
import ollama
import voluptuous as vol

from homeassistant.config_entries import (
    ConfigEntry,
    ConfigFlow,
    ConfigFlowResult,
    OptionsFlow,
)
from homeassistant.const import CONF_URL
from homeassistant.helpers.selector import (
    NumberSelector,
    NumberSelectorConfig,
    NumberSelectorMode,
    SelectOptionDict,
    SelectSelector,
    SelectSelectorConfig,
    TemplateSelector,
    TextSelector,
    TextSelectorConfig,
    TextSelectorType,
)

from .const import (
    CONF_MAX_HISTORY,
    CONF_MODEL,
    CONF_PROMPT,
    DEFAULT_MAX_HISTORY,
    DEFAULT_MODEL,
    DEFAULT_PROMPT,
    DEFAULT_TIMEOUT,
    DOMAIN,
    MODEL_NAMES,
)

_LOGGER = logging.getLogger(__name__)


STEP_USER_DATA_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_URL): TextSelector(
            TextSelectorConfig(type=TextSelectorType.URL)
        ),
    }
)


class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
    """Handle a config flow for Ollama."""

    VERSION = 1

    def __init__(self) -> None:
        """Initialize config flow."""
        self.url: str | None = None
        self.model: str | None = None
        self.client: ollama.AsyncClient | None = None
        self.download_task: asyncio.Task | None = None

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Handle the initial step."""
        user_input = user_input or {}
        self.url = user_input.get(CONF_URL, self.url)
        self.model = user_input.get(CONF_MODEL, self.model)

        if self.url is None:
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA, last_step=False
            )

        errors = {}

        try:
            self.client = ollama.AsyncClient(host=self.url)
            async with asyncio.timeout(DEFAULT_TIMEOUT):
                response = await self.client.list()

            downloaded_models: set[str] = {
                model_info["model"] for model_info in response.get("models", [])
            }
        except (TimeoutError, httpx.ConnectError):
            errors["base"] = "cannot_connect"
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Unexpected exception")
            errors["base"] = "unknown"

        if errors:
            return self.async_show_form(
                step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
            )

        if self.model is None:
            # Show models that have been downloaded first, followed by all known
            # models (only latest tags).
            models_to_list = [
                SelectOptionDict(label=f"{m} (downloaded)", value=m)
                for m in sorted(downloaded_models)
            ] + [
                SelectOptionDict(label=m, value=f"{m}:latest")
                for m in sorted(MODEL_NAMES)
                if m not in downloaded_models
            ]
            model_step_schema = vol.Schema(
                {
                    vol.Required(
                        CONF_MODEL, description={"suggested_value": DEFAULT_MODEL}
                    ): SelectSelector(
                        SelectSelectorConfig(options=models_to_list, custom_value=True)
                    ),
                }
            )

            return self.async_show_form(
                step_id="user",
                data_schema=model_step_schema,
            )

        if self.model not in downloaded_models:
            # Ollama server needs to download model first
            return await self.async_step_download()

        return self.async_create_entry(
            title=_get_title(self.model),
            data={CONF_URL: self.url, CONF_MODEL: self.model},
        )

    async def async_step_download(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step to wait for Ollama server to download a model."""
        assert self.model is not None
        assert self.client is not None

        if self.download_task is None:
            # Tell Ollama server to pull the model.
            # The task will block until the model and metadata are fully
            # downloaded.
            self.download_task = self.hass.async_create_background_task(
                self.client.pull(self.model), f"Downloading {self.model}"
            )

        if self.download_task.done():
            if err := self.download_task.exception():
                _LOGGER.exception("Unexpected error while downloading model: %s", err)
                return self.async_show_progress_done(next_step_id="failed")

            return self.async_show_progress_done(next_step_id="finish")

        return self.async_show_progress(
            step_id="download",
            progress_action="download",
            progress_task=self.download_task,
        )

    async def async_step_finish(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step after model downloading has succeeded."""
        assert self.url is not None
        assert self.model is not None

        return self.async_create_entry(
            title=_get_title(self.model),
            data={CONF_URL: self.url, CONF_MODEL: self.model},
        )

    async def async_step_failed(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Step after model downloading has failed."""
        return self.async_abort(reason="download_failed")

    @staticmethod
    def async_get_options_flow(
        config_entry: ConfigEntry,
    ) -> OptionsFlow:
        """Create the options flow."""
        return OllamaOptionsFlow(config_entry)


class OllamaOptionsFlow(OptionsFlow):
    """Ollama options flow."""

    def __init__(self, config_entry: ConfigEntry) -> None:
        """Initialize options flow."""
        self.config_entry = config_entry
        self.url: str = self.config_entry.data[CONF_URL]
        self.model: str = self.config_entry.data[CONF_MODEL]

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
    ) -> ConfigFlowResult:
        """Manage the options."""
        if user_input is not None:
            return self.async_create_entry(
                title=_get_title(self.model), data=user_input
            )

        options = self.config_entry.options or MappingProxyType({})
        schema = ollama_config_option_schema(options)
        return self.async_show_form(
            step_id="init",
            data_schema=vol.Schema(schema),
        )


def ollama_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
    """Ollama options schema."""
    return {
        vol.Optional(
            CONF_PROMPT,
            description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
        ): TemplateSelector(),
        vol.Optional(
            CONF_MAX_HISTORY,
            description={
                "suggested_value": options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)
            },
        ): NumberSelector(
            NumberSelectorConfig(
                min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
            )
        ),
    }


def _get_title(model: str) -> str:
    """Get title for config entry."""
    if model.endswith(":latest"):
        model = model.split(":", maxsplit=1)[0]

    return model
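The download step above relies on the config-flow show-progress pattern: the step method is re-entered while a background task runs, then routed to "finish" or "failed" once the task is done. A pure-asyncio sketch of the same control flow; the timings are arbitrary and pull_model stands in for client.pull:

import asyncio


async def pull_model() -> None:
    """Stand-in for ollama.AsyncClient.pull(): blocks until the download ends."""
    await asyncio.sleep(2)


async def flow_download_step() -> str:
    # Created once, like hass.async_create_background_task in async_step_download.
    task = asyncio.create_task(pull_model())
    while not task.done():
        # async_show_progress keeps re-invoking the step until the task finishes.
        await asyncio.sleep(0.5)
    if task.exception() is not None:
        return "failed"  # async_show_progress_done(next_step_id="failed")
    return "finish"  # async_show_progress_done(next_step_id="finish")


print(asyncio.run(flow_download_step()))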
homeassistant/components/ollama/const.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""Constants for the Ollama integration."""

DOMAIN = "ollama"

CONF_MODEL = "model"
CONF_PROMPT = "prompt"
DEFAULT_PROMPT = """{%- set used_domains = set([
  "binary_sensor",
  "climate",
  "cover",
  "fan",
  "light",
  "lock",
  "sensor",
  "switch",
  "weather",
]) %}
{%- set used_attributes = set([
  "temperature",
  "current_temperature",
  "temperature_unit",
  "brightness",
  "humidity",
  "unit_of_measurement",
  "device_class",
  "current_position",
  "percentage",
]) %}

This smart home is controlled by Home Assistant.
The current time is {{ now().strftime("%X") }}.
Today's date is {{ now().strftime("%x") }}.

An overview of the areas and the devices in this smart home:
```yaml
{%- for entity in exposed_entities: %}
  {%- if entity.domain not in used_domains: %}
    {%- continue %}
  {%- endif %}

- domain: {{ entity.domain }}
  {%- if entity.names | length == 1: %}
  name: {{ entity.names[0] }}
  {%- else: %}
  names:
  {%- for name in entity.names: %}
    - {{ name }}
  {%- endfor %}
  {%- endif %}
  {%- if entity.area_names | length == 1: %}
  area: {{ entity.area_names[0] }}
  {%- elif entity.area_names: %}
  areas:
  {%- for area_name in entity.area_names: %}
    - {{ area_name }}
  {%- endfor %}
  {%- endif %}
  state: {{ entity.state.state }}
  {%- set attributes_key_printed = False %}
  {%- for attr_name, attr_value in entity.state.attributes.items(): %}
    {%- if attr_name in used_attributes: %}
      {%- if not attributes_key_printed: %}
  attributes:
      {%- set attributes_key_printed = True %}
      {%- endif %}
    {{ attr_name }}: {{ attr_value }}
    {%- endif %}
  {%- endfor %}
{%- endfor %}
```

Answer the user's questions using the information about this smart home.
Keep your answers brief and do not apologize."""

KEEP_ALIVE_FOREVER = -1
DEFAULT_TIMEOUT = 5.0  # seconds

CONF_MAX_HISTORY = "max_history"
DEFAULT_MAX_HISTORY = 20

MAX_HISTORY_SECONDS = 60 * 60  # 1 hour

MODEL_NAMES = [  # https://ollama.com/library
    "gemma",
    "llama2",
    "mistral",
    "mixtral",
    "llava",
    "neural-chat",
    "codellama",
    "dolphin-mixtral",
    "qwen",
    "llama2-uncensored",
    "mistral-openorca",
    "deepseek-coder",
    "nous-hermes2",
    "phi",
    "orca-mini",
    "dolphin-mistral",
    "wizard-vicuna-uncensored",
    "vicuna",
    "tinydolphin",
    "llama2-chinese",
    "nomic-embed-text",
    "openhermes",
    "zephyr",
    "tinyllama",
    "openchat",
    "wizardcoder",
    "starcoder",
    "phind-codellama",
    "starcoder2",
]
DEFAULT_MODEL = "llama2:latest"
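For reference, with a single exposed kitchen light and none of the listed attributes set, DEFAULT_PROMPT renders to roughly the following (time, date, and entity values are illustrative):

This smart home is controlled by Home Assistant.
The current time is 14:30:00.
Today's date is 03/27/24.

An overview of the areas and the devices in this smart home:
```yaml
- domain: light
  name: kitchen light
  area: kitchen
  state: on
```

Answer the user's questions using the information about this smart home.
Keep your answers brief and do not apologize.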
homeassistant/components/ollama/manifest.json (new file, 11 lines)
@@ -0,0 +1,11 @@
{
  "domain": "ollama",
  "name": "Ollama",
  "codeowners": ["@synesthesiam"],
  "config_flow": true,
  "dependencies": ["conversation"],
  "documentation": "https://www.home-assistant.io/integrations/ollama",
  "integration_type": "service",
  "iot_class": "local_polling",
  "requirements": ["ollama-hass==0.1.7"]
}
homeassistant/components/ollama/models.py (new file, 47 lines)
@@ -0,0 +1,47 @@
"""Models for Ollama integration."""

from dataclasses import dataclass
from enum import StrEnum
from functools import cached_property

import ollama

from homeassistant.core import State


class MessageRole(StrEnum):
    """Role of a chat message."""

    SYSTEM = "system"  # prompt
    USER = "user"


@dataclass
class MessageHistory:
    """Chat message history."""

    timestamp: float
    """Timestamp of last use in seconds."""

    messages: list[ollama.Message]
    """List of message history, including system prompt and assistant responses."""

    @property
    def num_user_messages(self) -> int:
        """Return a count of user messages."""
        return sum(m["role"] == MessageRole.USER for m in self.messages)


@dataclass(frozen=True)
class ExposedEntity:
    """Relevant information about an exposed entity."""

    entity_id: str
    state: State
    names: list[str]
    area_names: list[str]

    @cached_property
    def domain(self) -> str:
        """Get domain from entity id."""
        return self.entity_id.split(".", maxsplit=1)[0]
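A short sketch of how num_user_messages feeds the 2x trimming rule in __init__.py; plain dicts stand in for ollama.Message, which is a TypedDict:

import time

from homeassistant.components.ollama.models import MessageHistory

history = MessageHistory(
    timestamp=time.monotonic(),
    messages=[
        {"role": "system", "content": "prompt"},  # always kept by _trim_history
        {"role": "user", "content": "message 1"},
        {"role": "assistant", "content": "response 1"},
        {"role": "user", "content": "message 2"},
        {"role": "assistant", "content": "response 2"},
    ],
)
assert history.num_user_messages == 2
# With max_history = 2, _trim_history keeps the system prompt plus the last
# 2 * 2 = 4 messages, i.e. the two most recent user/assistant turns.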
homeassistant/components/ollama/strings.json (new file, 33 lines)
@@ -0,0 +1,33 @@
{
  "config": {
    "step": {
      "user": {
        "data": {
          "url": "[%key:common::config_flow::data::url%]",
          "model": "Model"
        }
      },
      "download": {
        "title": "Downloading model"
      }
    },
    "error": {
      "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
      "download_failed": "Model downloading failed",
      "unknown": "[%key:common::config_flow::error::unknown%]"
    },
    "progress": {
      "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
    }
  },
  "options": {
    "step": {
      "init": {
        "data": {
          "prompt": "Prompt template",
          "max_history": "Max history messages"
        }
      }
    }
  }
}
homeassistant/generated/config_flows.py
@@ -360,6 +360,7 @@ FLOWS = {
    "nzbget",
    "obihai",
    "octoprint",
    "ollama",
    "omnilogic",
    "oncue",
    "ondilo_ico",
homeassistant/generated/integrations.json
@@ -4136,6 +4136,12 @@
      "config_flow": false,
      "iot_class": "cloud_polling"
    },
    "ollama": {
      "name": "Ollama",
      "integration_type": "service",
      "config_flow": true,
      "iot_class": "local_polling"
    },
    "ombi": {
      "name": "Ombi",
      "integration_type": "hub",
requirements_all.txt
@@ -1436,6 +1436,9 @@ odp-amsterdam==6.0.1
# homeassistant.components.oem
oemthermostat==1.1.1

# homeassistant.components.ollama
ollama-hass==0.1.7

# homeassistant.components.omnilogic
omnilogic==0.4.5
requirements_test_all.txt
@@ -1148,6 +1148,9 @@ objgraph==3.5.0
# homeassistant.components.garages_amsterdam
odp-amsterdam==6.0.1

# homeassistant.components.ollama
ollama-hass==0.1.7

# homeassistant.components.omnilogic
omnilogic==0.4.5
tests/components/ollama/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""Tests for the Ollama integration."""

from homeassistant.components import ollama
from homeassistant.components.ollama.const import DEFAULT_PROMPT

TEST_USER_DATA = {
    ollama.CONF_URL: "http://localhost:11434",
    ollama.CONF_MODEL: "test model",
}

TEST_OPTIONS = {
    ollama.CONF_PROMPT: DEFAULT_PROMPT,
    ollama.CONF_MAX_HISTORY: 2,
}
tests/components/ollama/conftest.py (new file, 37 lines)
@@ -0,0 +1,37 @@
"""Tests Ollama integration."""

from unittest.mock import patch

import pytest

from homeassistant.components import ollama
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component

from . import TEST_OPTIONS, TEST_USER_DATA

from tests.common import MockConfigEntry


@pytest.fixture
def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
    """Mock a config entry."""
    entry = MockConfigEntry(
        domain=ollama.DOMAIN,
        data=TEST_USER_DATA,
        options=TEST_OPTIONS,
    )
    entry.add_to_hass(hass)
    return entry


@pytest.fixture
async def mock_init_component(hass: HomeAssistant, mock_config_entry: MockConfigEntry):
    """Initialize integration."""
    assert await async_setup_component(hass, "homeassistant", {})

    with patch(
        "ollama.AsyncClient.list",
    ):
        assert await async_setup_component(hass, ollama.DOMAIN, {})
        await hass.async_block_till_done()
tests/components/ollama/test_config_flow.py (new file, 234 lines)
@@ -0,0 +1,234 @@
"""Test the Ollama config flow."""

import asyncio
from unittest.mock import patch

from httpx import ConnectError
import pytest

from homeassistant import config_entries
from homeassistant.components import ollama
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType

from tests.common import MockConfigEntry

TEST_MODEL = "test_model:latest"


async def test_form(hass: HomeAssistant) -> None:
    """Test flow when the model is already downloaded."""
    # Pretend we already set up a config entry.
    hass.config.components.add(ollama.DOMAIN)
    MockConfigEntry(
        domain=ollama.DOMAIN,
        state=config_entries.ConfigEntryState.LOADED,
    ).add_to_hass(hass)

    result = await hass.config_entries.flow.async_init(
        ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["type"] == FlowResultType.FORM
    assert result["errors"] is None

    with (
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
            # test model is already "downloaded"
            return_value={"models": [{"model": TEST_MODEL}]},
        ),
        patch(
            "homeassistant.components.ollama.async_setup_entry",
            return_value=True,
        ) as mock_setup_entry,
    ):
        # Step 1: URL
        result2 = await hass.config_entries.flow.async_configure(
            result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
        )
        await hass.async_block_till_done()

        # Step 2: model
        assert result2["type"] == FlowResultType.FORM
        result3 = await hass.config_entries.flow.async_configure(
            result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
        )
        await hass.async_block_till_done()

    assert result3["type"] == FlowResultType.CREATE_ENTRY
    assert result3["data"] == {
        ollama.CONF_URL: "http://localhost:11434",
        ollama.CONF_MODEL: TEST_MODEL,
    }
    assert len(mock_setup_entry.mock_calls) == 1


async def test_form_need_download(hass: HomeAssistant) -> None:
    """Test flow when a model needs to be downloaded."""
    # Pretend we already set up a config entry.
    hass.config.components.add(ollama.DOMAIN)
    MockConfigEntry(
        domain=ollama.DOMAIN,
        state=config_entries.ConfigEntryState.LOADED,
    ).add_to_hass(hass)

    result = await hass.config_entries.flow.async_init(
        ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    assert result["type"] == FlowResultType.FORM
    assert result["errors"] is None

    pull_ready = asyncio.Event()
    pull_called = asyncio.Event()
    pull_model: str | None = None

    async def pull(self, model: str, *args, **kwargs) -> None:
        nonlocal pull_model

        async with asyncio.timeout(1):
            await pull_ready.wait()

        pull_model = model
        pull_called.set()

    with (
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
            # No models are downloaded
            return_value={},
        ),
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
            pull,
        ),
        patch(
            "homeassistant.components.ollama.async_setup_entry",
            return_value=True,
        ) as mock_setup_entry,
    ):
        # Step 1: URL
        result2 = await hass.config_entries.flow.async_configure(
            result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
        )
        await hass.async_block_till_done()

        # Step 2: model
        assert result2["type"] == FlowResultType.FORM
        result3 = await hass.config_entries.flow.async_configure(
            result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
        )
        await hass.async_block_till_done()

        # Step 3: download
        assert result3["type"] == FlowResultType.SHOW_PROGRESS
        result4 = await hass.config_entries.flow.async_configure(
            result3["flow_id"],
        )
        await hass.async_block_till_done()

        # Run again without the task finishing.
        # We should still be downloading.
        assert result4["type"] == FlowResultType.SHOW_PROGRESS
        result4 = await hass.config_entries.flow.async_configure(
            result4["flow_id"],
        )
        await hass.async_block_till_done()
        assert result4["type"] == FlowResultType.SHOW_PROGRESS

        # Signal fake pull method to complete
        pull_ready.set()
        async with asyncio.timeout(1):
            await pull_called.wait()

        assert pull_model == TEST_MODEL

        # Step 4: finish
        result5 = await hass.config_entries.flow.async_configure(
            result4["flow_id"],
        )

    assert result5["type"] == FlowResultType.CREATE_ENTRY
    assert result5["data"] == {
        ollama.CONF_URL: "http://localhost:11434",
        ollama.CONF_MODEL: TEST_MODEL,
    }
    assert len(mock_setup_entry.mock_calls) == 1


async def test_options(
    hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
    """Test the options form."""
    options_flow = await hass.config_entries.options.async_init(
        mock_config_entry.entry_id
    )
    options = await hass.config_entries.options.async_configure(
        options_flow["flow_id"],
        {ollama.CONF_PROMPT: "test prompt", ollama.CONF_MAX_HISTORY: 100},
    )
    await hass.async_block_till_done()
    assert options["type"] == FlowResultType.CREATE_ENTRY
    assert options["data"] == {
        ollama.CONF_PROMPT: "test prompt",
        ollama.CONF_MAX_HISTORY: 100,
    }


@pytest.mark.parametrize(
    ("side_effect", "error"),
    [
        (ConnectError(message=""), "cannot_connect"),
        (RuntimeError(), "unknown"),
    ],
)
async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
    """Test we handle errors."""
    result = await hass.config_entries.flow.async_init(
        ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    with patch(
        "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
        side_effect=side_effect,
    ):
        result2 = await hass.config_entries.flow.async_configure(
            result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
        )

    assert result2["type"] == FlowResultType.FORM
    assert result2["errors"] == {"base": error}


async def test_download_error(hass: HomeAssistant) -> None:
    """Test we handle errors while downloading a model."""
    result = await hass.config_entries.flow.async_init(
        ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )

    with (
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
            return_value={},
        ),
        patch(
            "homeassistant.components.ollama.config_flow.ollama.AsyncClient.pull",
            side_effect=RuntimeError(),
        ),
    ):
        result2 = await hass.config_entries.flow.async_configure(
            result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
        )
        await hass.async_block_till_done()

        assert result2["type"] == FlowResultType.FORM
        result3 = await hass.config_entries.flow.async_configure(
            result2["flow_id"], {ollama.CONF_MODEL: TEST_MODEL}
        )
        await hass.async_block_till_done()

        assert result3["type"] == FlowResultType.SHOW_PROGRESS
        result4 = await hass.config_entries.flow.async_configure(result3["flow_id"])
        await hass.async_block_till_done()

    assert result4["type"] == FlowResultType.ABORT
    assert result4["reason"] == "download_failed"
tests/components/ollama/test_init.py (new file, 366 lines)
@@ -0,0 +1,366 @@
"""Tests for the Ollama integration."""

from unittest.mock import AsyncMock, patch

from httpx import ConnectError
from ollama import Message, ResponseError
import pytest

from homeassistant.components import conversation, ollama
from homeassistant.components.homeassistant.exposed_entities import async_expose_entity
from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import (
    area_registry as ar,
    device_registry as dr,
    entity_registry as er,
    intent,
)
from homeassistant.setup import async_setup_component

from tests.common import MockConfigEntry


async def test_chat(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
    area_registry: ar.AreaRegistry,
    device_registry: dr.DeviceRegistry,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test that the chat function is called with the appropriate arguments."""

    # Create some areas, devices, and entities
    area_kitchen = area_registry.async_get_or_create("kitchen_id")
    area_kitchen = area_registry.async_update(area_kitchen.id, name="kitchen")
    area_bedroom = area_registry.async_get_or_create("bedroom_id")
    area_bedroom = area_registry.async_update(area_bedroom.id, name="bedroom")
    area_office = area_registry.async_get_or_create("office_id")
    area_office = area_registry.async_update(area_office.id, name="office")

    entry = MockConfigEntry()
    entry.add_to_hass(hass)
    kitchen_device = device_registry.async_get_or_create(
        config_entry_id=entry.entry_id,
        connections=set(),
        identifiers={("demo", "id-1234")},
    )
    device_registry.async_update_device(kitchen_device.id, area_id=area_kitchen.id)

    kitchen_light = entity_registry.async_get_or_create("light", "demo", "1234")
    kitchen_light = entity_registry.async_update_entity(
        kitchen_light.entity_id, device_id=kitchen_device.id
    )
    hass.states.async_set(
        kitchen_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
    )

    bedroom_light = entity_registry.async_get_or_create("light", "demo", "5678")
    bedroom_light = entity_registry.async_update_entity(
        bedroom_light.entity_id, area_id=area_bedroom.id
    )
    hass.states.async_set(
        bedroom_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "bedroom light"}
    )

    # Hide the office light
    office_light = entity_registry.async_get_or_create("light", "demo", "ABCD")
    office_light = entity_registry.async_update_entity(
        office_light.entity_id, area_id=area_office.id
    )
    hass.states.async_set(
        office_light.entity_id, "on", attributes={ATTR_FRIENDLY_NAME: "office light"}
    )
    async_expose_entity(hass, conversation.DOMAIN, office_light.entity_id, False)

    with patch(
        "ollama.AsyncClient.chat",
        return_value={"message": {"role": "assistant", "content": "test response"}},
    ) as mock_chat:
        result = await conversation.async_converse(
            hass,
            "test message",
            None,
            Context(),
            agent_id=mock_config_entry.entry_id,
        )

    assert mock_chat.call_count == 1
    args = mock_chat.call_args.kwargs
    prompt = args["messages"][0]["content"]

    assert args["model"] == "test model"
    assert args["messages"] == [
        Message({"role": "system", "content": prompt}),
        Message({"role": "user", "content": "test message"}),
    ]

    # Verify only exposed devices/areas are in prompt
    assert "kitchen light" in prompt
    assert "bedroom light" in prompt
    assert "office light" not in prompt
    assert "office" not in prompt

    assert (
        result.response.response_type == intent.IntentResponseType.ACTION_DONE
    ), result
    assert result.response.speech["plain"]["speech"] == "test response"


async def test_message_history_trimming(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that a single message history is trimmed according to the config."""
    response_idx = 0

    def response(*args, **kwargs) -> dict:
        nonlocal response_idx
        response_idx += 1
        return {"message": {"role": "assistant", "content": f"response {response_idx}"}}

    with patch(
        "ollama.AsyncClient.chat",
        side_effect=response,
    ) as mock_chat:
        # mock_init_component sets "max_history" to 2
        for i in range(5):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id="1234",
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

    assert mock_chat.call_count == 5
    args = mock_chat.call_args_list
    prompt = args[0].kwargs["messages"][0]["content"]

    # system + user-1
    assert len(args[0].kwargs["messages"]) == 2
    assert args[0].kwargs["messages"][1]["content"] == "message 1"

    # Full history
    # system + user-1 + assistant-1 + user-2
    assert len(args[1].kwargs["messages"]) == 4
    assert args[1].kwargs["messages"][0]["role"] == "system"
    assert args[1].kwargs["messages"][0]["content"] == prompt
    assert args[1].kwargs["messages"][1]["role"] == "user"
    assert args[1].kwargs["messages"][1]["content"] == "message 1"
    assert args[1].kwargs["messages"][2]["role"] == "assistant"
    assert args[1].kwargs["messages"][2]["content"] == "response 1"
    assert args[1].kwargs["messages"][3]["role"] == "user"
    assert args[1].kwargs["messages"][3]["content"] == "message 2"

    # Full history
    # system + user-1 + assistant-1 + user-2 + assistant-2 + user-3
    assert len(args[2].kwargs["messages"]) == 6
    assert args[2].kwargs["messages"][0]["role"] == "system"
    assert args[2].kwargs["messages"][0]["content"] == prompt
    assert args[2].kwargs["messages"][1]["role"] == "user"
    assert args[2].kwargs["messages"][1]["content"] == "message 1"
    assert args[2].kwargs["messages"][2]["role"] == "assistant"
    assert args[2].kwargs["messages"][2]["content"] == "response 1"
    assert args[2].kwargs["messages"][3]["role"] == "user"
    assert args[2].kwargs["messages"][3]["content"] == "message 2"
    assert args[2].kwargs["messages"][4]["role"] == "assistant"
    assert args[2].kwargs["messages"][4]["content"] == "response 2"
    assert args[2].kwargs["messages"][5]["role"] == "user"
    assert args[2].kwargs["messages"][5]["content"] == "message 3"

    # Trimmed down to two user messages.
    # system + user-2 + assistant-2 + user-3 + assistant-3 + user-4
    assert len(args[3].kwargs["messages"]) == 6
    assert args[3].kwargs["messages"][0]["role"] == "system"
    assert args[3].kwargs["messages"][0]["content"] == prompt
    assert args[3].kwargs["messages"][1]["role"] == "user"
    assert args[3].kwargs["messages"][1]["content"] == "message 2"
    assert args[3].kwargs["messages"][2]["role"] == "assistant"
    assert args[3].kwargs["messages"][2]["content"] == "response 2"
    assert args[3].kwargs["messages"][3]["role"] == "user"
    assert args[3].kwargs["messages"][3]["content"] == "message 3"
    assert args[3].kwargs["messages"][4]["role"] == "assistant"
    assert args[3].kwargs["messages"][4]["content"] == "response 3"
    assert args[3].kwargs["messages"][5]["role"] == "user"
    assert args[3].kwargs["messages"][5]["content"] == "message 4"

    # Trimmed down to two user messages.
    # system + user-3 + assistant-3 + user-4 + assistant-4 + user-5
    assert len(args[4].kwargs["messages"]) == 6
    assert args[4].kwargs["messages"][0]["role"] == "system"
    assert args[4].kwargs["messages"][0]["content"] == prompt
    assert args[4].kwargs["messages"][1]["role"] == "user"
    assert args[4].kwargs["messages"][1]["content"] == "message 3"
    assert args[4].kwargs["messages"][2]["role"] == "assistant"
    assert args[4].kwargs["messages"][2]["content"] == "response 3"
    assert args[4].kwargs["messages"][3]["role"] == "user"
    assert args[4].kwargs["messages"][3]["content"] == "message 4"
    assert args[4].kwargs["messages"][4]["role"] == "assistant"
    assert args[4].kwargs["messages"][4]["content"] == "response 4"
    assert args[4].kwargs["messages"][5]["role"] == "user"
    assert args[4].kwargs["messages"][5]["content"] == "message 5"


async def test_message_history_pruning(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that old message histories are pruned."""
    with patch(
        "ollama.AsyncClient.chat",
        return_value={"message": {"role": "assistant", "content": "test response"}},
    ):
        # Create 3 different message histories
        conversation_ids: list[str] = []
        for i in range(3):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id=None,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result
            assert isinstance(result.conversation_id, str)
            conversation_ids.append(result.conversation_id)

        agent = await conversation._get_agent_manager(hass).async_get_agent(
            mock_config_entry.entry_id
        )
        assert isinstance(agent, ollama.OllamaAgent)
        assert len(agent._history) == 3
        assert agent._history.keys() == set(conversation_ids)

        # Modify the timestamps of the first 2 histories so they will be pruned
        # on the next cycle.
        for conversation_id in conversation_ids[:2]:
            # Move back 2 hours
            agent._history[conversation_id].timestamp -= 2 * 60 * 60

        # Next cycle
        result = await conversation.async_converse(
            hass,
            "test message",
            conversation_id=None,
            context=Context(),
            agent_id=mock_config_entry.entry_id,
        )
        assert (
            result.response.response_type == intent.IntentResponseType.ACTION_DONE
        ), result

        # Only the most recent histories should remain
        assert len(agent._history) == 2
        assert conversation_ids[-1] in agent._history
        assert result.conversation_id in agent._history


async def test_message_history_unlimited(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test that message history is not trimmed when max_history = 0."""
    conversation_id = "1234"
    with (
        patch(
            "ollama.AsyncClient.chat",
            return_value={"message": {"role": "assistant", "content": "test response"}},
        ),
        patch.object(mock_config_entry, "options", {ollama.CONF_MAX_HISTORY: 0}),
    ):
        for i in range(100):
            result = await conversation.async_converse(
                hass,
                f"message {i+1}",
                conversation_id=conversation_id,
                context=Context(),
                agent_id=mock_config_entry.entry_id,
            )
            assert (
                result.response.response_type == intent.IntentResponseType.ACTION_DONE
            ), result

        agent = await conversation._get_agent_manager(hass).async_get_agent(
            mock_config_entry.entry_id
        )
        assert isinstance(agent, ollama.OllamaAgent)

        assert len(agent._history) == 1
        assert conversation_id in agent._history
        assert agent._history[conversation_id].num_user_messages == 100


async def test_error_handling(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
    """Test error handling during converse."""
    with patch(
        "ollama.AsyncClient.chat",
        new_callable=AsyncMock,
        side_effect=ResponseError("test error"),
    ):
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.error_code == "unknown", result


async def test_template_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
    """Test that template error handling works."""
    hass.config_entries.async_update_entry(
        mock_config_entry,
        options={
            "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
        },
    )
    with patch(
        "ollama.AsyncClient.list",
    ):
        await hass.config_entries.async_setup(mock_config_entry.entry_id)
        await hass.async_block_till_done()
        result = await conversation.async_converse(
            hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
        )

    assert result.response.response_type == intent.IntentResponseType.ERROR, result
    assert result.response.error_code == "unknown", result


async def test_conversation_agent(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_init_component,
) -> None:
    """Test OllamaAgent."""
    agent = await conversation._get_agent_manager(hass).async_get_agent(
        mock_config_entry.entry_id
    )
    assert agent.supported_languages == MATCH_ALL


@pytest.mark.parametrize(
    ("side_effect", "error"),
    [
        (ConnectError(message="Connect error"), "Connect error"),
        (RuntimeError("Runtime error"), "Runtime error"),
    ],
)
async def test_init_error(
    hass: HomeAssistant, mock_config_entry: MockConfigEntry, caplog, side_effect, error
) -> None:
    """Test initialization errors."""
    with patch(
        "ollama.AsyncClient.list",
        side_effect=side_effect,
    ):
        assert await async_setup_component(hass, ollama.DOMAIN, {})
        await hass.async_block_till_done()
        assert error in caplog.text