Update ollama to use the ChatLog/ChatSession APIs (#138167)

* Update ollama to use the ChatLog/ChatSession APIs

* Add documentation about history trimming.

* Revert changes to chat_log.py

* Explicitly check for SystemContent when converting system messages

* Remove half of a comment
Author: Allen Porter
Date: 2025-02-09 13:52:01 -08:00
Committed by: GitHub
Parent: 379bf10675
Commit: 57ab567d08

4 changed files with 128 additions and 223 deletions
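The commit replaces the integration's private, timestamp-pruned history dict with the shared chat_session/ChatLog helpers shown in the diffs below. As a quick orientation, here is a minimal sketch of that pattern using only APIs that appear in this change; the standalone `handle` function and its canned reply are illustrative, not code from this commit.

# Minimal sketch of the ChatSession/ChatLog pattern adopted by this commit.
# The free-standing `handle` function is hypothetical; the real logic lives on
# OllamaConversationEntity (see the conversation.py diff below).
from homeassistant.components import conversation
from homeassistant.core import HomeAssistant
from homeassistant.helpers import chat_session, intent


async def handle(
    hass: HomeAssistant, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
    """Resolve the shared chat log for this conversation id, then answer."""
    with (
        # Reuses the session tied to user_input.conversation_id, or starts a new one.
        chat_session.async_get_chat_session(
            hass, user_input.conversation_id
        ) as session,
        # The chat log stores system/user/assistant/tool content shared across agents.
        conversation.async_get_chat_log(hass, session, user_input) as chat_log,
    ):
        # A real agent renders its prompt via chat_log.async_update_llm_data(),
        # calls its model, and records replies with chat_log.async_add_assistant_content().
        intent_response = intent.IntentResponse(language=user_input.language)
        intent_response.async_set_speech("canned reply for illustration")
        return conversation.ConversationResult(
            response=intent_response, conversation_id=chat_log.conversation_id
        )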


@@ -5,22 +5,18 @@ from __future__ import annotations

 from collections.abc import Callable
 import json
 import logging
-import time
 from typing import Any, Literal

 import ollama
-import voluptuous as vol
 from voluptuous_openapi import convert

 from homeassistant.components import assist_pipeline, conversation
-from homeassistant.components.conversation import trace
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError, TemplateError
-from homeassistant.helpers import intent, llm, template
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import chat_session, intent, llm
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
-from homeassistant.util import ulid as ulid_util

 from .const import (
     CONF_KEEP_ALIVE,
@@ -32,7 +28,6 @@ from .const import (
     DEFAULT_MAX_HISTORY,
     DEFAULT_NUM_CTX,
     DOMAIN,
-    MAX_HISTORY_SECONDS,
 )

 from .models import MessageHistory, MessageRole
@@ -93,6 +88,44 @@ def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
     return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}


+def _convert_content(
+    chat_content: conversation.Content
+    | conversation.ToolResultContent
+    | conversation.AssistantContent,
+) -> ollama.Message:
+    """Create tool response content."""
+    if isinstance(chat_content, conversation.ToolResultContent):
+        return ollama.Message(
+            role=MessageRole.TOOL.value,
+            content=json.dumps(chat_content.tool_result),
+        )
+    if isinstance(chat_content, conversation.AssistantContent):
+        return ollama.Message(
+            role=MessageRole.ASSISTANT.value,
+            content=chat_content.content,
+            tool_calls=[
+                ollama.Message.ToolCall(
+                    function=ollama.Message.ToolCall.Function(
+                        name=tool_call.tool_name,
+                        arguments=tool_call.tool_args,
+                    )
+                )
+                for tool_call in chat_content.tool_calls or ()
+            ],
+        )
+    if isinstance(chat_content, conversation.UserContent):
+        return ollama.Message(
+            role=MessageRole.USER.value,
+            content=chat_content.content,
+        )
+    if isinstance(chat_content, conversation.SystemContent):
+        return ollama.Message(
+            role=MessageRole.SYSTEM.value,
+            content=chat_content.content,
+        )
+    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+
+
 class OllamaConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
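For reference, a sketch of what the new `_convert_content` helper returns for two common chat log entries. The sample text and agent id are made up; the `UserContent` constructor is assumed to take only `content`, while the `AssistantContent` fields match those used elsewhere in this diff.

# Illustrative only: mapping ChatLog content to ollama messages via _convert_content.
import ollama

from homeassistant.components import conversation
from homeassistant.components.ollama.conversation import _convert_content

user = conversation.UserContent(content="Turn on the kitchen light")
assert _convert_content(user) == ollama.Message(
    role="user", content="Turn on the kitchen light"
)

assistant = conversation.AssistantContent(
    agent_id="conversation.llama", content="Done!", tool_calls=None
)
# Assistant text is preserved; with no tool calls the converted list is empty.
assert _convert_content(assistant) == ollama.Message(
    role="assistant", content="Done!", tool_calls=[]
)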
@@ -105,7 +138,6 @@ class OllamaConversationEntity(
         self.entry = entry

-        # conversation id -> message history
-        self._history: dict[str, MessageHistory] = {}
+        # message history
         self._attr_name = entry.title
         self._attr_unique_id = entry.entry_id
         if self.entry.options.get(CONF_LLM_HASS_API):
@@ -138,121 +170,48 @@ class OllamaConversationEntity(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
+        with (
+            chat_session.async_get_chat_session(
+                self.hass, user_input.conversation_id
+            ) as session,
+            conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
+        ):
+            return await self._async_handle_message(user_input, chat_log)
+
+    async def _async_handle_message(
+        self,
+        user_input: conversation.ConversationInput,
+        chat_log: conversation.ChatLog,
+    ) -> conversation.ConversationResult:
+        """Call the API."""
         settings = {**self.entry.data, **self.entry.options}

         client = self.hass.data[DOMAIN][self.entry.entry_id]
-        conversation_id = user_input.conversation_id or ulid_util.ulid_now()
         model = settings[CONF_MODEL]

-        intent_response = intent.IntentResponse(language=user_input.language)
-        llm_api: llm.APIInstance | None = None
-        tools: list[dict[str, Any]] | None = None
-        user_name: str | None = None
-        llm_context = llm.LLMContext(
-            platform=DOMAIN,
-            context=user_input.context,
-            user_prompt=user_input.text,
-            language=user_input.language,
-            assistant=conversation.DOMAIN,
-            device_id=user_input.device_id,
-        )
-
-        if settings.get(CONF_LLM_HASS_API):
-            try:
-                llm_api = await llm.async_get_api(
-                    self.hass,
-                    settings[CONF_LLM_HASS_API],
-                    llm_context,
-                )
-            except HomeAssistantError as err:
-                _LOGGER.error("Error getting LLM API: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Error preparing LLM API: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=user_input.conversation_id
-                )
+        try:
+            await chat_log.async_update_llm_data(
+                DOMAIN,
+                user_input,
+                settings.get(CONF_LLM_HASS_API),
+                settings.get(CONF_PROMPT),
+            )
+        except conversation.ConverseError as err:
+            return err.as_conversation_result()
+
+        tools: list[dict[str, Any]] | None = None
+        if chat_log.llm_api:
             tools = [
-                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
             ]
-        if (
-            user_input.context
-            and user_input.context.user_id
-            and (
-                user := await self.hass.auth.async_get_user(user_input.context.user_id)
-            )
-        ):
-            user_name = user.name
-
-        # Look up message history
-        message_history: MessageHistory | None = None
-        message_history = self._history.get(conversation_id)
-        if message_history is None:
-            # New history
-            #
-            # Render prompt and error out early if there's a problem
-            try:
-                prompt_parts = [
-                    template.Template(
-                        llm.BASE_PROMPT
-                        + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
-                        self.hass,
-                    ).async_render(
-                        {
-                            "ha_name": self.hass.config.location_name,
-                            "user_name": user_name,
-                            "llm_context": llm_context,
-                        },
-                        parse_result=False,
-                    )
-                ]
-            except TemplateError as err:
-                _LOGGER.error("Error rendering prompt: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Sorry, I had a problem generating my prompt: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
-                )
-
-            if llm_api:
-                prompt_parts.append(llm_api.api_prompt)
-
-            prompt = "\n".join(prompt_parts)
-            _LOGGER.debug("Prompt: %s", prompt)
-            _LOGGER.debug("Tools: %s", tools)
-
-            message_history = MessageHistory(
-                timestamp=time.monotonic(),
-                messages=[
-                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
-                ],
-            )
-            self._history[conversation_id] = message_history
-        else:
-            # Bump timestamp so this conversation won't get cleaned up
-            message_history.timestamp = time.monotonic()
-
-        # Clean up old histories
-        self._prune_old_histories()
-
-        # Trim this message history to keep a maximum number of *user* messages
+        message_history: MessageHistory = MessageHistory(
+            [_convert_content(content) for content in chat_log.content]
+        )
         max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
         self._trim_history(message_history, max_messages)

-        # Add new user message
-        message_history.messages.append(
-            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
-        )
-
-        trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL,
-            {"messages": message_history.messages},
-        )
-
         # Get response
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
@@ -269,77 +228,75 @@
                 )
             except (ollama.RequestError, ollama.ResponseError) as err:
                 _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Sorry, I had a problem talking to the Ollama server: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
-                )
+                raise HomeAssistantError(
+                    f"Sorry, I had a problem talking to the Ollama server: {err}"
+                ) from err

             response_message = response["message"]
+
+            content = response_message.get("content")
+            tool_calls = response_message.get("tool_calls")
+
             message_history.messages.append(
                 ollama.Message(
                     role=response_message["role"],
-                    content=response_message.get("content"),
-                    tool_calls=response_message.get("tool_calls"),
+                    content=content,
+                    tool_calls=tool_calls,
                 )
             )

-            tool_calls = response_message.get("tool_calls")
-            if not tool_calls or not llm_api:
-                break
-
-            for tool_call in tool_calls:
-                tool_input = llm.ToolInput(
+            tool_inputs = [
+                llm.ToolInput(
                     tool_name=tool_call["function"]["name"],
                     tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
                 )
-                _LOGGER.debug(
-                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
-                )
-
-                try:
-                    tool_response = await llm_api.async_call_tool(tool_input)
-                except (HomeAssistantError, vol.Invalid) as e:
-                    tool_response = {"error": type(e).__name__}
-                    if str(e):
-                        tool_response["error_text"] = str(e)
-
-                _LOGGER.debug("Tool response: %s", tool_response)
-                message_history.messages.append(
+                for tool_call in tool_calls or ()
+            ]
+
+            message_history.messages.extend(
+                [
                     ollama.Message(
                         role=MessageRole.TOOL.value,
-                        content=json.dumps(tool_response),
+                        content=json.dumps(tool_response.tool_result),
                     )
-                )
+                    async for tool_response in chat_log.async_add_assistant_content(
+                        conversation.AssistantContent(
+                            agent_id=user_input.agent_id,
+                            content=content,
+                            tool_calls=tool_inputs or None,
+                        )
+                    )
+                ]
+            )
+
+            if not tool_calls:
+                break

         # Create intent response
+        intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(response_message["content"])
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=conversation_id
+            response=intent_response, conversation_id=chat_log.conversation_id
         )

-    def _prune_old_histories(self) -> None:
-        """Remove old message histories."""
-        now = time.monotonic()
-        self._history = {
-            conversation_id: message_history
-            for conversation_id, message_history in self._history.items()
-            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
-        }
-
     def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
-        """Trims excess messages from a single history."""
+        """Trims excess messages from a single history.
+
+        This sets the max history to allow a configurable size history may take
+        up in the context window.
+
+        Note that some messages in the history may not be from ollama only, and
+        may come from other agents, so the assumptions here may not strictly hold,
+        but generally should be effective.
+        """
         if max_messages < 1:
             # Keep all messages
             return

-        if message_history.num_user_messages >= max_messages:
+        # Ignore the in progress user message
+        num_previous_rounds = message_history.num_user_messages - 1
+        if num_previous_rounds >= max_messages:
             # Trim history but keep system prompt (first message).
             # Every other message should be an assistant message, so keep 2x
-            # message objects.
-            num_keep = 2 * max_messages
+            # message objects. Also keep the last in progress user message
+            num_keep = 2 * max_messages + 1
             drop_index = len(message_history.messages) - num_keep
             message_history.messages = [
                 message_history.messages[0]
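A worked example of the trimming arithmetic above, as a standalone sketch. `max_messages` counts previous user turns; the in-progress user message and the system prompt are always kept. The rebuild step just past the end of this hunk is assumed to be `[messages[0]] + messages[drop_index:]`, matching the unchanged code that follows it.

import ollama


def trim(messages: list[ollama.Message], max_messages: int) -> list[ollama.Message]:
    """Standalone copy of the trimming rule, for illustration only."""
    if max_messages < 1:
        return messages  # unlimited history
    num_user_messages = sum(m.role == "user" for m in messages)
    num_previous_rounds = num_user_messages - 1  # ignore the in-progress user turn
    if num_previous_rounds < max_messages:
        return messages
    # Keep max_messages user/assistant pairs plus the in-progress user message.
    num_keep = 2 * max_messages + 1
    drop_index = len(messages) - num_keep
    return [messages[0]] + messages[drop_index:]


history = [
    ollama.Message(role="system", content="prompt"),
    ollama.Message(role="user", content="turn 1"),
    ollama.Message(role="assistant", content="reply 1"),
    ollama.Message(role="user", content="turn 2"),
    ollama.Message(role="assistant", content="reply 2"),
    ollama.Message(role="user", content="turn 3 (in progress)"),
]
# With max_messages=1: num_keep = 3 and drop_index = 3, so only the oldest round is dropped.
assert [m.content for m in trim(history, max_messages=1)] == [
    "prompt",
    "turn 2",
    "reply 2",
    "turn 3 (in progress)",
]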


@@ -19,9 +19,6 @@ class MessageRole(StrEnum):
 class MessageHistory:
     """Chat message history."""

-    timestamp: float
-    """Timestamp of last use in seconds."""
-
     messages: list[ollama.Message]
     """List of message history, including system prompt and assistant responses."""


@@ -1,7 +1,7 @@
 # serializer version: 1
 # name: test_unknown_hass_api
   dict({
-    'conversation_id': None,
+    'conversation_id': '1234',
     'response': IntentResponse(
       card=dict({
       }),
@@ -20,7 +20,7 @@
       speech=dict({
         'plain': dict({
           'extra_data': None,
-          'speech': 'Error preparing LLM API: API non-existing not found',
+          'speech': 'Error preparing LLM API',
         }),
       }),
       speech_slots=dict({


@@ -325,7 +325,11 @@ async def test_unknown_hass_api
     await hass.async_block_till_done()

     result = await conversation.async_converse(
-        hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+        hass,
+        "hello",
+        "1234",
+        Context(),
+        agent_id=mock_config_entry.entry_id,
     )

     assert result == snapshot
@@ -428,70 +432,17 @@ async def test_message_history_trimming(
     assert args[4].kwargs["messages"][5]["content"] == "message 5"


-async def test_message_history_pruning(
-    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
-) -> None:
-    """Test that old message histories are pruned."""
-    with patch(
-        "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
-    ):
-        # Create 3 different message histories
-        conversation_ids: list[str] = []
-        for i in range(3):
-            result = await conversation.async_converse(
-                hass,
-                f"message {i + 1}",
-                conversation_id=None,
-                context=Context(),
-                agent_id=mock_config_entry.entry_id,
-            )
-            assert (
-                result.response.response_type == intent.IntentResponseType.ACTION_DONE
-            ), result
-            assert isinstance(result.conversation_id, str)
-            conversation_ids.append(result.conversation_id)
-
-        agent = conversation.get_agent_manager(hass).async_get_agent(
-            mock_config_entry.entry_id
-        )
-        assert len(agent._history) == 3
-        assert agent._history.keys() == set(conversation_ids)
-
-        # Modify the timestamps of the first 2 histories so they will be pruned
-        # on the next cycle.
-        for conversation_id in conversation_ids[:2]:
-            # Move back 2 hours
-            agent._history[conversation_id].timestamp -= 2 * 60 * 60
-
-        # Next cycle
-        result = await conversation.async_converse(
-            hass,
-            "test message",
-            conversation_id=None,
-            context=Context(),
-            agent_id=mock_config_entry.entry_id,
-        )
-        assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
-            result
-        )
-
-        # Only the most recent histories should remain
-        assert len(agent._history) == 2
-        assert conversation_ids[-1] in agent._history
-        assert result.conversation_id in agent._history
-
-
 async def test_message_history_unlimited(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
 ) -> None:
     """Test that message history is not trimmed when max_history = 0."""
     conversation_id = "1234"
     with (
         patch(
             "ollama.AsyncClient.chat",
             return_value={"message": {"role": "assistant", "content": "test response"}},
-        ),
+        ) as mock_chat,
     ):
         hass.config_entries.async_update_entry(
             mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
@@ -508,13 +459,13 @@ async def test_message_history_unlimited(
                 result.response.response_type == intent.IntentResponseType.ACTION_DONE
             ), result

-        agent = conversation.get_agent_manager(hass).async_get_agent(
-            mock_config_entry.entry_id
-        )
-
-        assert len(agent._history) == 1
-        assert conversation_id in agent._history
-        assert agent._history[conversation_id].num_user_messages == 100
+        args = mock_chat.call_args_list
+        assert len(args) == 100
+        recorded_messages = args[-1].kwargs["messages"]
+        message_count = sum(
+            (message["role"] == "user") for message in recorded_messages
+        )
+        assert message_count == 100


 async def test_error_handling(