Update ollama to use the ChatLog/ChatSession APIs (#138167)

* Update ollama to use the ChatLog/ChatSession APIs

* Add documentation about history trimming.

* Revert changes to chat_log.py

* Explicitly check for SystemContent when converting system messages

* Remove half of a comment
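
At its core, the change drops the entity's own conversation-id -> MessageHistory bookkeeping in favor of the shared session and chat-log helpers. A condensed sketch of the new pattern, assembled from the diff below (a sketch only, not the full handler):

    with (
        chat_session.async_get_chat_session(
            self.hass, user_input.conversation_id
        ) as session,
        conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
    ):
        # The chat log is now the source of truth for history; the ollama
        # message list is rebuilt from it on every turn.
        messages = [_convert_content(content) for content in chat_log.content]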
Allen Porter authored on 2025-02-09 13:52:01 -08:00, committed by GitHub
parent 379bf10675
commit 57ab567d08
4 changed files with 128 additions and 223 deletions


@@ -5,22 +5,18 @@ from __future__ import annotations
from collections.abc import Callable
import json
import logging
import time
from typing import Any, Literal
import ollama
import voluptuous as vol
from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, intent, llm
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid as ulid_util
from .const import (
CONF_KEEP_ALIVE,
@@ -32,7 +28,6 @@ from .const import (
DEFAULT_MAX_HISTORY,
DEFAULT_NUM_CTX,
DOMAIN,
MAX_HISTORY_SECONDS,
)
from .models import MessageHistory, MessageRole
@@ -93,6 +88,44 @@ def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
def _convert_content(
chat_content: conversation.Content
| conversation.ToolResultContent
| conversation.AssistantContent,
) -> ollama.Message:
"""Create tool response content."""
if isinstance(chat_content, conversation.ToolResultContent):
return ollama.Message(
role=MessageRole.TOOL.value,
content=json.dumps(chat_content.tool_result),
)
if isinstance(chat_content, conversation.AssistantContent):
return ollama.Message(
role=MessageRole.ASSISTANT.value,
content=chat_content.content,
tool_calls=[
ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(
name=tool_call.tool_name,
arguments=tool_call.tool_args,
)
)
for tool_call in chat_content.tool_calls or ()
],
)
if isinstance(chat_content, conversation.UserContent):
return ollama.Message(
role=MessageRole.USER.value,
content=chat_content.content,
)
if isinstance(chat_content, conversation.SystemContent):
return ollama.Message(
role=MessageRole.SYSTEM.value,
content=chat_content.content,
)
raise ValueError(f"Unexpected content type: {type(chat_content)}")
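
For illustration, a hypothetical standalone call to the new _convert_content helper (not part of the diff; assumes conversation.UserContent(content=...) constructs a user entry, and attribute access on the returned message, which holds for the pydantic-based ollama client types):

    # Hypothetical usage example for _convert_content.
    msg = _convert_content(conversation.UserContent(content="turn on the lights"))
    assert msg.role == MessageRole.USER.value
    assert msg.content == "turn on the lights"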
class OllamaConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@@ -105,7 +138,6 @@ class OllamaConversationEntity(
self.entry = entry
# conversation id -> message history
self._history: dict[str, MessageHistory] = {}
self._attr_name = entry.title
self._attr_unique_id = entry.entry_id
if self.entry.options.get(CONF_LLM_HASS_API):
@@ -138,121 +170,48 @@ class OllamaConversationEntity(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
with (
chat_session.async_get_chat_session(
self.hass, user_input.conversation_id
) as session,
conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
):
return await self._async_handle_message(user_input, chat_log)
async def _async_handle_message(
self,
user_input: conversation.ConversationInput,
chat_log: conversation.ChatLog,
) -> conversation.ConversationResult:
"""Call the API."""
settings = {**self.entry.data, **self.entry.options}
client = self.hass.data[DOMAIN][self.entry.entry_id]
conversation_id = user_input.conversation_id or ulid_util.ulid_now()
model = settings[CONF_MODEL]
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.APIInstance | None = None
tools: list[dict[str, Any]] | None = None
user_name: str | None = None
llm_context = llm.LLMContext(
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
if settings.get(CONF_LLM_HASS_API):
try:
llm_api = await llm.async_get_api(
self.hass,
settings[CONF_LLM_HASS_API],
llm_context,
)
except HomeAssistantError as err:
_LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
try:
await chat_log.async_update_llm_data(
DOMAIN,
user_input,
settings.get(CONF_LLM_HASS_API),
settings.get(CONF_PROMPT),
)
except conversation.ConverseError as err:
return err.as_conversation_result()
tools: list[dict[str, Any]] | None = None
if chat_log.llm_api:
tools = [
_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
_format_tool(tool, chat_log.llm_api.custom_serializer)
for tool in chat_log.llm_api.tools
]
if (
user_input.context
and user_input.context.user_id
and (
user := await self.hass.auth.async_get_user(user_input.context.user_id)
)
):
user_name = user.name
# Look up message history
message_history: MessageHistory | None = None
message_history = self._history.get(conversation_id)
if message_history is None:
# New history
#
# Render prompt and error out early if there's a problem
try:
prompt_parts = [
template.Template(
llm.BASE_PROMPT
+ settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
self.hass,
).async_render(
{
"ha_name": self.hass.config.location_name,
"user_name": user_name,
"llm_context": llm_context,
},
parse_result=False,
)
]
except TemplateError as err:
_LOGGER.error("Error rendering prompt: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem generating my prompt: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
if llm_api:
prompt_parts.append(llm_api.api_prompt)
prompt = "\n".join(prompt_parts)
_LOGGER.debug("Prompt: %s", prompt)
_LOGGER.debug("Tools: %s", tools)
message_history = MessageHistory(
timestamp=time.monotonic(),
messages=[
ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
],
)
self._history[conversation_id] = message_history
else:
# Bump timestamp so this conversation won't get cleaned up
message_history.timestamp = time.monotonic()
# Clean up old histories
self._prune_old_histories()
# Trim this message history to keep a maximum number of *user* messages
message_history: MessageHistory = MessageHistory(
[_convert_content(content) for content in chat_log.content]
)
max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
self._trim_history(message_history, max_messages)
# Add new user message
message_history.messages.append(
ollama.Message(role=MessageRole.USER.value, content=user_input.text)
)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{"messages": message_history.messages},
)
# Get response
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
@@ -269,77 +228,75 @@ class OllamaConversationEntity(
)
except (ollama.RequestError, ollama.ResponseError) as err:
_LOGGER.error("Unexpected error talking to Ollama server: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to the Ollama server: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
raise HomeAssistantError(
f"Sorry, I had a problem talking to the Ollama server: {err}"
) from err
response_message = response["message"]
content = response_message.get("content")
tool_calls = response_message.get("tool_calls")
message_history.messages.append(
ollama.Message(
role=response_message["role"],
content=response_message.get("content"),
tool_calls=response_message.get("tool_calls"),
content=content,
tool_calls=tool_calls,
)
)
tool_calls = response_message.get("tool_calls")
if not tool_calls or not llm_api:
break
for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_inputs = [
llm.ToolInput(
tool_name=tool_call["function"]["name"],
tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
)
_LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)
for tool_call in tool_calls or ()
]
try:
tool_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)
_LOGGER.debug("Tool response: %s", tool_response)
message_history.messages.append(
message_history.messages.extend(
[
ollama.Message(
role=MessageRole.TOOL.value,
content=json.dumps(tool_response),
content=json.dumps(tool_response.tool_result),
)
)
async for tool_response in chat_log.async_add_assistant_content(
conversation.AssistantContent(
agent_id=user_input.agent_id,
content=content,
tool_calls=tool_inputs or None,
)
)
]
)
if not tool_calls:
break
# Create intent response
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response_message["content"])
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
response=intent_response, conversation_id=chat_log.conversation_id
)
def _prune_old_histories(self) -> None:
"""Remove old message histories."""
now = time.monotonic()
self._history = {
conversation_id: message_history
for conversation_id, message_history in self._history.items()
if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
}
def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
"""Trims excess messages from a single history."""
"""Trims excess messages from a single history.
This sets the max history to allow a configurable size history may take
up in the context window.
Note that some messages in the history may not be from ollama only, and
may come from other anents, so the assumptions here may not strictly hold,
but generally should be effective.
"""
if max_messages < 1:
# Keep all messages
return
if message_history.num_user_messages >= max_messages:
# Ignore the in-progress user message
num_previous_rounds = message_history.num_user_messages - 1
if num_previous_rounds >= max_messages:
# Trim history but keep system prompt (first message).
# Every other message should be an assistant message, so keep 2x
# message objects.
num_keep = 2 * max_messages
# message objects. Also keep the last in-progress user message.
num_keep = 2 * max_messages + 1
drop_index = len(message_history.messages) - num_keep
message_history.messages = [
message_history.messages[0]
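
A worked example of the new trimming arithmetic (assuming the truncated assignment above keeps messages[0] plus messages[drop_index:]): with max_messages = 2 and a history [system, u1, a1, u2, a2, u3, a3, u4], where u4 is the in-progress user message, num_user_messages is 4 and num_previous_rounds is 3, so trimming triggers; num_keep = 2 * 2 + 1 = 5 and drop_index = 8 - 5 = 3, leaving [system, u2, a2, u3, a3, u4]: the system prompt, the last two complete rounds, and the in-progress user message.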


@@ -19,9 +19,6 @@ class MessageRole(StrEnum):
class MessageHistory:
"""Chat message history."""
timestamp: float
"""Timestamp of last use in seconds."""
messages: list[ollama.Message]
"""List of message history, including system prompt and assistant responses."""


@@ -1,7 +1,7 @@
# serializer version: 1
# name: test_unknown_hass_api
dict({
'conversation_id': None,
'conversation_id': '1234',
'response': IntentResponse(
card=dict({
}),
@@ -20,7 +20,7 @@
speech=dict({
'plain': dict({
'extra_data': None,
'speech': 'Error preparing LLM API: API non-existing not found',
'speech': 'Error preparing LLM API',
}),
}),
speech_slots=dict({


@@ -325,7 +325,11 @@ async def test_unknown_hass_api(
await hass.async_block_till_done()
result = await conversation.async_converse(
hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
hass,
"hello",
"1234",
Context(),
agent_id=mock_config_entry.entry_id,
)
assert result == snapshot
@@ -428,70 +432,17 @@ async def test_message_history_trimming(
assert args[4].kwargs["messages"][5]["content"] == "message 5"
async def test_message_history_pruning(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that old message histories are pruned."""
with patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
):
# Create 3 different message histories
conversation_ids: list[str] = []
for i in range(3):
result = await conversation.async_converse(
hass,
f"message {i + 1}",
conversation_id=None,
context=Context(),
agent_id=mock_config_entry.entry_id,
)
assert (
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
assert isinstance(result.conversation_id, str)
conversation_ids.append(result.conversation_id)
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
)
assert len(agent._history) == 3
assert agent._history.keys() == set(conversation_ids)
# Modify the timestamps of the first 2 histories so they will be pruned
# on the next cycle.
for conversation_id in conversation_ids[:2]:
# Move back 2 hours
agent._history[conversation_id].timestamp -= 2 * 60 * 60
# Next cycle
result = await conversation.async_converse(
hass,
"test message",
conversation_id=None,
context=Context(),
agent_id=mock_config_entry.entry_id,
)
assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
result
)
# Only the most recent histories should remain
assert len(agent._history) == 2
assert conversation_ids[-1] in agent._history
assert result.conversation_id in agent._history
async def test_message_history_unlimited(
hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
) -> None:
"""Test that message history is not trimmed when max_history = 0."""
conversation_id = "1234"
with (
patch(
"ollama.AsyncClient.chat",
return_value={"message": {"role": "assistant", "content": "test response"}},
),
) as mock_chat,
):
hass.config_entries.async_update_entry(
mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
@@ -508,13 +459,13 @@ async def test_message_history_unlimited(
result.response.response_type == intent.IntentResponseType.ACTION_DONE
), result
agent = conversation.get_agent_manager(hass).async_get_agent(
mock_config_entry.entry_id
args = mock_chat.call_args_list
assert len(args) == 100
recorded_messages = args[-1].kwargs["messages"]
message_count = sum(
(message["role"] == "user") for message in recorded_messages
)
assert len(agent._history) == 1
assert conversation_id in agent._history
assert agent._history[conversation_id].num_user_messages == 100
assert message_count == 100
async def test_error_handling(