Mirror of https://github.com/home-assistant/core.git
Update ollama to use the ChatLog/ChatSession APIs (#138167)
* Update ollama to use the ChatLog/ChatSession APIs
* Add documentation about history trimming.
* Revert changes to chat_log.py
* Explicitly check for SystemContent when converting system messages
* Remove half of a comment
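In brief, the migration replaces the integration's hand-rolled per-conversation history (a dict keyed by ULID, with manual prompt rendering and pruning) with the shared chat_session/ChatLog helpers. A condensed sketch of the new flow, distilled from the diff below — `llm_api_name` and `prompt` stand in for the CONF_LLM_HASS_API and CONF_PROMPT settings, and error handling is omitted:

    # Condensed sketch, not the full implementation from the diff below.
    with (
        chat_session.async_get_chat_session(hass, user_input.conversation_id) as session,
        conversation.async_get_chat_log(hass, session, user_input) as chat_log,
    ):
        # ChatLog now owns prompt rendering and LLM API/tool setup...
        await chat_log.async_update_llm_data(DOMAIN, user_input, llm_api_name, prompt)
        # ...and its shared content replaces the per-entity history dict.
        ollama_messages = [_convert_content(content) for content in chat_log.content]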
--- a/homeassistant/components/ollama/conversation.py
+++ b/homeassistant/components/ollama/conversation.py
@@ -5,22 +5,18 @@ from __future__ import annotations
 from collections.abc import Callable
 import json
 import logging
-import time
 from typing import Any, Literal

 import ollama
-import voluptuous as vol
 from voluptuous_openapi import convert

 from homeassistant.components import assist_pipeline, conversation
-from homeassistant.components.conversation import trace
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError, TemplateError
-from homeassistant.helpers import intent, llm, template
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers import chat_session, intent, llm
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
-from homeassistant.util import ulid as ulid_util

 from .const import (
     CONF_KEEP_ALIVE,
@@ -32,7 +28,6 @@ from .const import (
     DEFAULT_MAX_HISTORY,
     DEFAULT_NUM_CTX,
     DOMAIN,
-    MAX_HISTORY_SECONDS,
 )
 from .models import MessageHistory, MessageRole

@@ -93,6 +88,44 @@ def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
     return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}


+def _convert_content(
+    chat_content: conversation.Content
+    | conversation.ToolResultContent
+    | conversation.AssistantContent,
+) -> ollama.Message:
+    """Create tool response content."""
+    if isinstance(chat_content, conversation.ToolResultContent):
+        return ollama.Message(
+            role=MessageRole.TOOL.value,
+            content=json.dumps(chat_content.tool_result),
+        )
+    if isinstance(chat_content, conversation.AssistantContent):
+        return ollama.Message(
+            role=MessageRole.ASSISTANT.value,
+            content=chat_content.content,
+            tool_calls=[
+                ollama.Message.ToolCall(
+                    function=ollama.Message.ToolCall.Function(
+                        name=tool_call.tool_name,
+                        arguments=tool_call.tool_args,
+                    )
+                )
+                for tool_call in chat_content.tool_calls or ()
+            ],
+        )
+    if isinstance(chat_content, conversation.UserContent):
+        return ollama.Message(
+            role=MessageRole.USER.value,
+            content=chat_content.content,
+        )
+    if isinstance(chat_content, conversation.SystemContent):
+        return ollama.Message(
+            role=MessageRole.SYSTEM.value,
+            content=chat_content.content,
+        )
+    raise ValueError(f"Unexpected content type: {type(chat_content)}")
+
+
 class OllamaConversationEntity(
     conversation.ConversationEntity, conversation.AbstractConversationAgent
 ):
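The new `_convert_content` helper above is the bridge from shared chat-log content to Ollama's message type. As a rough illustration of what it produces — the constructor arguments below are assumptions for the example, not taken from this commit:

    # Illustrative only: map shared chat-log entries to ollama.Message objects.
    entries = [
        conversation.SystemContent(content="You are a voice assistant."),  # assumed ctor
        conversation.UserContent(content="Turn on the kitchen lights."),   # assumed ctor
    ]
    messages = [_convert_content(entry) for entry in entries]
    # messages[0].role == "system"; messages[1].role == "user"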
@@ -105,7 +138,6 @@ class OllamaConversationEntity(
         self.entry = entry

         # conversation id -> message history
-        self._history: dict[str, MessageHistory] = {}
         self._attr_name = entry.title
         self._attr_unique_id = entry.entry_id
         if self.entry.options.get(CONF_LLM_HASS_API):
@@ -138,121 +170,48 @@ class OllamaConversationEntity(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
+        with (
+            chat_session.async_get_chat_session(
+                self.hass, user_input.conversation_id
+            ) as session,
+            conversation.async_get_chat_log(self.hass, session, user_input) as chat_log,
+        ):
+            return await self._async_handle_message(user_input, chat_log)
+
+    async def _async_handle_message(
+        self,
+        user_input: conversation.ConversationInput,
+        chat_log: conversation.ChatLog,
+    ) -> conversation.ConversationResult:
+        """Call the API."""
         settings = {**self.entry.data, **self.entry.options}

         client = self.hass.data[DOMAIN][self.entry.entry_id]
-        conversation_id = user_input.conversation_id or ulid_util.ulid_now()
         model = settings[CONF_MODEL]
-        intent_response = intent.IntentResponse(language=user_input.language)
-        llm_api: llm.APIInstance | None = None
-        tools: list[dict[str, Any]] | None = None
-        user_name: str | None = None
-        llm_context = llm.LLMContext(
-            platform=DOMAIN,
-            context=user_input.context,
-            user_prompt=user_input.text,
-            language=user_input.language,
-            assistant=conversation.DOMAIN,
-            device_id=user_input.device_id,
-        )

-        if settings.get(CONF_LLM_HASS_API):
-            try:
-                llm_api = await llm.async_get_api(
-                    self.hass,
-                    settings[CONF_LLM_HASS_API],
-                    llm_context,
-                )
-            except HomeAssistantError as err:
-                _LOGGER.error("Error getting LLM API: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Error preparing LLM API: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=user_input.conversation_id
-                )
+        try:
+            await chat_log.async_update_llm_data(
+                DOMAIN,
+                user_input,
+                settings.get(CONF_LLM_HASS_API),
+                settings.get(CONF_PROMPT),
+            )
+        except conversation.ConverseError as err:
+            return err.as_conversation_result()
+
+        tools: list[dict[str, Any]] | None = None
+        if chat_log.llm_api:
             tools = [
-                _format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools
+                _format_tool(tool, chat_log.llm_api.custom_serializer)
+                for tool in chat_log.llm_api.tools
             ]

-        if (
-            user_input.context
-            and user_input.context.user_id
-            and (
-                user := await self.hass.auth.async_get_user(user_input.context.user_id)
-            )
-        ):
-            user_name = user.name
-
-        # Look up message history
-        message_history: MessageHistory | None = None
-        message_history = self._history.get(conversation_id)
-        if message_history is None:
-            # New history
-            #
-            # Render prompt and error out early if there's a problem
-            try:
-                prompt_parts = [
-                    template.Template(
-                        llm.BASE_PROMPT
-                        + settings.get(CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT),
-                        self.hass,
-                    ).async_render(
-                        {
-                            "ha_name": self.hass.config.location_name,
-                            "user_name": user_name,
-                            "llm_context": llm_context,
-                        },
-                        parse_result=False,
-                    )
-                ]
-
-            except TemplateError as err:
-                _LOGGER.error("Error rendering prompt: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Sorry, I had a problem generating my prompt: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
-                )
-
-            if llm_api:
-                prompt_parts.append(llm_api.api_prompt)
-
-            prompt = "\n".join(prompt_parts)
-            _LOGGER.debug("Prompt: %s", prompt)
-            _LOGGER.debug("Tools: %s", tools)
-
-            message_history = MessageHistory(
-                timestamp=time.monotonic(),
-                messages=[
-                    ollama.Message(role=MessageRole.SYSTEM.value, content=prompt)
-                ],
-            )
-            self._history[conversation_id] = message_history
-        else:
-            # Bump timestamp so this conversation won't get cleaned up
-            message_history.timestamp = time.monotonic()
-
-        # Clean up old histories
-        self._prune_old_histories()
-
-        # Trim this message history to keep a maximum number of *user* messages
+        message_history: MessageHistory = MessageHistory(
+            [_convert_content(content) for content in chat_log.content]
+        )
         max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
         self._trim_history(message_history, max_messages)

-        # Add new user message
-        message_history.messages.append(
-            ollama.Message(role=MessageRole.USER.value, content=user_input.text)
-        )
-
-        trace.async_conversation_trace_append(
-            trace.ConversationTraceEventType.AGENT_DETAIL,
-            {"messages": message_history.messages},
-        )
-
         # Get response
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
@@ -269,77 +228,75 @@ class OllamaConversationEntity(
                 )
             except (ollama.RequestError, ollama.ResponseError) as err:
                 _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
-                intent_response.async_set_error(
-                    intent.IntentResponseErrorCode.UNKNOWN,
-                    f"Sorry, I had a problem talking to the Ollama server: {err}",
-                )
-                return conversation.ConversationResult(
-                    response=intent_response, conversation_id=conversation_id
-                )
+                raise HomeAssistantError(
+                    f"Sorry, I had a problem talking to the Ollama server: {err}"
+                ) from err

             response_message = response["message"]
+            content = response_message.get("content")
+            tool_calls = response_message.get("tool_calls")
             message_history.messages.append(
                 ollama.Message(
                     role=response_message["role"],
-                    content=response_message.get("content"),
-                    tool_calls=response_message.get("tool_calls"),
+                    content=content,
+                    tool_calls=tool_calls,
                 )
             )
-
-            tool_calls = response_message.get("tool_calls")
-            if not tool_calls or not llm_api:
-                break
-
-            for tool_call in tool_calls:
-                tool_input = llm.ToolInput(
+            tool_inputs = [
+                llm.ToolInput(
                     tool_name=tool_call["function"]["name"],
                     tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
                 )
-                _LOGGER.debug(
-                    "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
-                )
-
-                try:
-                    tool_response = await llm_api.async_call_tool(tool_input)
-                except (HomeAssistantError, vol.Invalid) as e:
-                    tool_response = {"error": type(e).__name__}
-                    if str(e):
-                        tool_response["error_text"] = str(e)
-
-                _LOGGER.debug("Tool response: %s", tool_response)
-                message_history.messages.append(
+                for tool_call in tool_calls or ()
+            ]

+            message_history.messages.extend(
+                [
                     ollama.Message(
                         role=MessageRole.TOOL.value,
-                        content=json.dumps(tool_response),
+                        content=json.dumps(tool_response.tool_result),
                     )
-                )
+                    async for tool_response in chat_log.async_add_assistant_content(
+                        conversation.AssistantContent(
+                            agent_id=user_input.agent_id,
+                            content=content,
+                            tool_calls=tool_inputs or None,
+                        )
+                    )
+                ]
+            )
+
+            if not tool_calls:
+                break

         # Create intent response
+        intent_response = intent.IntentResponse(language=user_input.language)
         intent_response.async_set_speech(response_message["content"])
         return conversation.ConversationResult(
-            response=intent_response, conversation_id=conversation_id
+            response=intent_response, conversation_id=chat_log.conversation_id
         )

-    def _prune_old_histories(self) -> None:
-        """Remove old message histories."""
-        now = time.monotonic()
-        self._history = {
-            conversation_id: message_history
-            for conversation_id, message_history in self._history.items()
-            if (now - message_history.timestamp) <= MAX_HISTORY_SECONDS
-        }
-
     def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
-        """Trims excess messages from a single history."""
+        """Trims excess messages from a single history.
+
+        This trims the history to a configurable size, limiting how much of the
+        context window it may take up.
+
+        Note that some messages in the history may not be from ollama and may
+        come from other agents, so the assumptions here may not strictly hold,
+        but they should generally be effective.
+        """
         if max_messages < 1:
             # Keep all messages
             return

-        if message_history.num_user_messages >= max_messages:
+        # Ignore the in progress user message
+        num_previous_rounds = message_history.num_user_messages - 1
+        if num_previous_rounds >= max_messages:
             # Trim history but keep system prompt (first message).
             # Every other message should be an assistant message, so keep 2x
-            # message objects.
-            num_keep = 2 * max_messages
+            # message objects. Also keep the last in progress user message
+            num_keep = 2 * max_messages + 1
             drop_index = len(message_history.messages) - num_keep
             message_history.messages = [
                 message_history.messages[0]
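To make the trimming change concrete, here is a toy rendering of the new arithmetic with plain strings standing in for ollama.Message objects (illustrative, not the integration's code). With `max_messages = 3`, the old code kept `2 * 3 = 6` trailing messages and counted the just-appended user message as a full round; the new code ignores that in-progress message when counting rounds and reserves one extra slot for it:

    # Toy model of _trim_history's new bookkeeping (strings stand in for messages).
    max_messages = 3
    history = ["system"] + ["user", "assistant"] * 5 + ["user"]  # trailing in-progress msg

    num_previous_rounds = history.count("user") - 1  # 5: ignore the in-progress message
    if num_previous_rounds >= max_messages:
        num_keep = 2 * max_messages + 1  # 7: three full rounds plus the in-progress msg
        drop_index = len(history) - num_keep
        history = [history[0]] + history[drop_index:]  # always keep the system prompt

    assert history == ["system", "user", "assistant", "user", "assistant", "user",
                       "assistant", "user"]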
--- a/homeassistant/components/ollama/models.py
+++ b/homeassistant/components/ollama/models.py
@@ -19,9 +19,6 @@ class MessageRole(StrEnum):
 class MessageHistory:
     """Chat message history."""

-    timestamp: float
-    """Timestamp of last use in seconds."""
-
     messages: list[ollama.Message]
     """List of message history, including system prompt and assistant responses."""

--- a/tests/components/ollama/snapshots/test_conversation.ambr
+++ b/tests/components/ollama/snapshots/test_conversation.ambr
@@ -1,7 +1,7 @@
 # serializer version: 1
 # name: test_unknown_hass_api
   dict({
-    'conversation_id': None,
+    'conversation_id': '1234',
     'response': IntentResponse(
       card=dict({
       }),
@@ -20,7 +20,7 @@
       speech=dict({
         'plain': dict({
           'extra_data': None,
-          'speech': 'Error preparing LLM API: API non-existing not found',
+          'speech': 'Error preparing LLM API',
         }),
       }),
       speech_slots=dict({
--- a/tests/components/ollama/test_conversation.py
+++ b/tests/components/ollama/test_conversation.py
@@ -325,7 +325,11 @@ async def test_unknown_hass_api(
     await hass.async_block_till_done()

     result = await conversation.async_converse(
-        hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id
+        hass,
+        "hello",
+        "1234",
+        Context(),
+        agent_id=mock_config_entry.entry_id,
     )

     assert result == snapshot
@@ -428,70 +432,17 @@ async def test_message_history_trimming(
     assert args[4].kwargs["messages"][5]["content"] == "message 5"


-async def test_message_history_pruning(
-    hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
-) -> None:
-    """Test that old message histories are pruned."""
-    with patch(
-        "ollama.AsyncClient.chat",
-        return_value={"message": {"role": "assistant", "content": "test response"}},
-    ):
-        # Create 3 different message histories
-        conversation_ids: list[str] = []
-        for i in range(3):
-            result = await conversation.async_converse(
-                hass,
-                f"message {i + 1}",
-                conversation_id=None,
-                context=Context(),
-                agent_id=mock_config_entry.entry_id,
-            )
-            assert (
-                result.response.response_type == intent.IntentResponseType.ACTION_DONE
-            ), result
-            assert isinstance(result.conversation_id, str)
-            conversation_ids.append(result.conversation_id)
-
-        agent = conversation.get_agent_manager(hass).async_get_agent(
-            mock_config_entry.entry_id
-        )
-        assert len(agent._history) == 3
-        assert agent._history.keys() == set(conversation_ids)
-
-        # Modify the timestamps of the first 2 histories so they will be pruned
-        # on the next cycle.
-        for conversation_id in conversation_ids[:2]:
-            # Move back 2 hours
-            agent._history[conversation_id].timestamp -= 2 * 60 * 60
-
-        # Next cycle
-        result = await conversation.async_converse(
-            hass,
-            "test message",
-            conversation_id=None,
-            context=Context(),
-            agent_id=mock_config_entry.entry_id,
-        )
-        assert result.response.response_type == intent.IntentResponseType.ACTION_DONE, (
-            result
-        )
-
-        # Only the most recent histories should remain
-        assert len(agent._history) == 2
-        assert conversation_ids[-1] in agent._history
-        assert result.conversation_id in agent._history
-
-
 async def test_message_history_unlimited(
     hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component
 ) -> None:
     """Test that message history is not trimmed when max_history = 0."""
     conversation_id = "1234"

     with (
         patch(
             "ollama.AsyncClient.chat",
             return_value={"message": {"role": "assistant", "content": "test response"}},
-        ),
+        ) as mock_chat,
     ):
         hass.config_entries.async_update_entry(
             mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0}
@@ -508,13 +459,13 @@ async def test_message_history_unlimited(
             result.response.response_type == intent.IntentResponseType.ACTION_DONE
         ), result

-        agent = conversation.get_agent_manager(hass).async_get_agent(
-            mock_config_entry.entry_id
+        args = mock_chat.call_args_list
+        assert len(args) == 100
+        recorded_messages = args[-1].kwargs["messages"]
+        message_count = sum(
+            (message["role"] == "user") for message in recorded_messages
         )
-        assert len(agent._history) == 1
-        assert conversation_id in agent._history
-        assert agent._history[conversation_id].num_user_messages == 100
+        assert message_count == 100


 async def test_error_handling(