Convert Ollama to subentries

This commit is contained in:
Paulus Schoutsen
2025-06-22 01:54:36 +00:00
parent 56f4039ac2
commit 0c1c865ab3
9 changed files with 320 additions and 128 deletions

View File

@@ -8,11 +8,11 @@ import logging
import httpx import httpx
import ollama import ollama
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_URL, Platform from homeassistant.const import CONF_URL, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.util.ssl import get_default_context from homeassistant.util.ssl import get_default_context
from .const import ( from .const import (
@@ -22,6 +22,7 @@ from .const import (
CONF_NUM_CTX, CONF_NUM_CTX,
CONF_PROMPT, CONF_PROMPT,
CONF_THINK, CONF_THINK,
DEFAULT_CONVERSATION_NAME,
DEFAULT_TIMEOUT, DEFAULT_TIMEOUT,
DOMAIN, DOMAIN,
) )
@@ -65,3 +66,43 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return False return False
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)
return True return True
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Migrate old entry."""
if entry.version == 1:
# Migrate from version 1 to version 2
# Move conversation-specific options to a subentry
subentry = ConfigSubentry(
data=entry.options,
subentry_type="conversation",
title=DEFAULT_CONVERSATION_NAME,
unique_id=None,
)
hass.config_entries.async_add_subentry(
entry,
subentry,
)
# Migrate conversation entity to be linked to subentry
ent_reg = er.async_get(hass)
conversation_entity = ent_reg.async_get_entity_id(
"conversation",
DOMAIN,
entry.entry_id,
)
if conversation_entity is not None:
ent_reg.async_update_entity(
conversation_entity,
config_subentry_id=subentry.subentry_id,
new_unique_id=subentry.subentry_id,
)
# Remove options from the main entry
hass.config_entries.async_update_entry(
entry,
options={},
version=2,
)
return True

View File

@@ -16,10 +16,11 @@ from homeassistant.config_entries import (
ConfigEntry, ConfigEntry,
ConfigFlow, ConfigFlow,
ConfigFlowResult, ConfigFlowResult,
OptionsFlow, ConfigSubentryFlow,
SubentryFlowResult,
) )
from homeassistant.const import CONF_LLM_HASS_API, CONF_URL from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm from homeassistant.helpers import llm
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
BooleanSelector, BooleanSelector,
@@ -43,6 +44,7 @@ from .const import (
CONF_NUM_CTX, CONF_NUM_CTX,
CONF_PROMPT, CONF_PROMPT,
CONF_THINK, CONF_THINK,
DEFAULT_CONVERSATION_NAME,
DEFAULT_KEEP_ALIVE, DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY, DEFAULT_MAX_HISTORY,
DEFAULT_MODEL, DEFAULT_MODEL,
@@ -70,7 +72,7 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Ollama.""" """Handle a config flow for Ollama."""
VERSION = 1 VERSION = 2
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize config flow.""" """Initialize config flow."""
@@ -148,6 +150,14 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry( return self.async_create_entry(
title=_get_title(self.model), title=_get_title(self.model),
data={CONF_URL: self.url, CONF_MODEL: self.model}, data={CONF_URL: self.url, CONF_MODEL: self.model},
subentries=[
{
"subentry_type": "conversation",
"data": {},
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
],
) )
async def async_step_download( async def async_step_download(
@@ -189,6 +199,14 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry( return self.async_create_entry(
title=_get_title(self.model), title=_get_title(self.model),
data={CONF_URL: self.url, CONF_MODEL: self.model}, data={CONF_URL: self.url, CONF_MODEL: self.model},
subentries=[
{
"subentry_type": "conversation",
"data": {},
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
],
) )
async def async_step_failed( async def async_step_failed(
@@ -197,41 +215,58 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
"""Step after model downloading has failed.""" """Step after model downloading has failed."""
return self.async_abort(reason="download_failed") return self.async_abort(reason="download_failed")
@staticmethod @classmethod
def async_get_options_flow( @callback
config_entry: ConfigEntry, def async_get_supported_subentry_types(
) -> OptionsFlow: cls, config_entry: ConfigEntry
"""Create the options flow.""" ) -> dict[str, type[ConfigSubentryFlow]]:
return OllamaOptionsFlow(config_entry) """Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
class OllamaOptionsFlow(OptionsFlow): class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Ollama options flow.""" """Flow for managing conversation subentries."""
def __init__(self, config_entry: ConfigEntry) -> None: @property
"""Initialize options flow.""" def _is_new(self) -> bool:
self.url: str = config_entry.data[CONF_URL] """Return if this is a new subentry."""
self.model: str = config_entry.data[CONF_MODEL] return self.source == "user"
async def async_step_init( async def async_step_set_options(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> SubentryFlowResult:
"""Manage the options.""" """Set conversation options."""
if user_input is not None: errors: dict[str, str] = {}
if user_input is None:
if self._is_new:
options = {}
else:
options = self._get_reconfigure_subentry().data.copy()
elif self._is_new:
return self.async_create_entry( return self.async_create_entry(
title=_get_title(self.model), data=user_input title=user_input.pop(CONF_NAME),
data=user_input,
)
else:
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=user_input,
) )
options: Mapping[str, Any] = self.config_entry.options or {} schema = ollama_config_option_schema(self.hass, self._is_new, options)
schema = ollama_config_option_schema(self.hass, options)
return self.async_show_form( return self.async_show_form(
step_id="init", step_id="set_options", data_schema=vol.Schema(schema), errors=errors
data_schema=vol.Schema(schema),
) )
async_step_user = async_step_set_options
async_step_reconfigure = async_step_set_options
def ollama_config_option_schema( def ollama_config_option_schema(
hass: HomeAssistant, options: Mapping[str, Any] hass: HomeAssistant, is_new: bool, options: Mapping[str, Any]
) -> dict: ) -> dict:
"""Ollama options schema.""" """Ollama options schema."""
hass_apis: list[SelectOptionDict] = [ hass_apis: list[SelectOptionDict] = [
@@ -242,54 +277,72 @@ def ollama_config_option_schema(
for api in llm.async_get_apis(hass) for api in llm.async_get_apis(hass)
] ]
return { if is_new:
vol.Optional( schema: dict[vol.Required | vol.Optional, Any] = {
CONF_PROMPT, vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str,
description={ }
"suggested_value": options.get( else:
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT schema = {}
schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
)
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
vol.Optional(
CONF_NUM_CTX,
description={
"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)
},
): NumberSelector(
NumberSelectorConfig(
min=MIN_NUM_CTX,
max=MAX_NUM_CTX,
step=1,
mode=NumberSelectorMode.BOX,
) )
}, ),
): TemplateSelector(), vol.Optional(
vol.Optional( CONF_MAX_HISTORY,
CONF_LLM_HASS_API, description={
description={"suggested_value": options.get(CONF_LLM_HASS_API)}, "suggested_value": options.get(
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)), CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY
vol.Optional( )
CONF_NUM_CTX, },
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)}, ): NumberSelector(
): NumberSelector( NumberSelectorConfig(
NumberSelectorConfig( min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX )
) ),
), vol.Optional(
vol.Optional( CONF_KEEP_ALIVE,
CONF_MAX_HISTORY, description={
description={ "suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)
"suggested_value": options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY) },
}, ): NumberSelector(
): NumberSelector( NumberSelectorConfig(
NumberSelectorConfig( min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX )
) ),
), vol.Optional(
vol.Optional( CONF_THINK,
CONF_KEEP_ALIVE, description={
description={ "suggested_value": options.get("think", DEFAULT_THINK),
"suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE) },
}, ): BooleanSelector(),
): NumberSelector( }
NumberSelectorConfig( )
min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
) return schema
),
vol.Optional(
CONF_THINK,
description={
"suggested_value": options.get("think", DEFAULT_THINK),
},
): BooleanSelector(),
}
def _get_title(model: str) -> str: def _get_title(model: str) -> str:

View File

@@ -157,3 +157,5 @@ MODEL_NAMES = [ # https://ollama.com/library
"zephyr", "zephyr",
] ]
DEFAULT_MODEL = "llama3.2:latest" DEFAULT_MODEL = "llama3.2:latest"
DEFAULT_CONVERSATION_NAME = "Ollama Conversation"

View File

@@ -11,7 +11,7 @@ import ollama
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -44,8 +44,14 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback, async_add_entities: AddConfigEntryEntitiesCallback,
) -> None: ) -> None:
"""Set up conversation entities.""" """Set up conversation entities."""
agent = OllamaConversationEntity(config_entry) for subentry in config_entry.subentries.values():
async_add_entities([agent]) if subentry.subentry_type != "conversation":
continue
async_add_entities(
[OllamaConversationEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
def _format_tool( def _format_tool(
@@ -174,17 +180,15 @@ class OllamaConversationEntity(
): ):
"""Ollama conversation agent.""" """Ollama conversation agent."""
_attr_has_entity_name = True
_attr_supports_streaming = True _attr_supports_streaming = True
def __init__(self, entry: ConfigEntry) -> None: def __init__(self, entry: ConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.entry = entry self.entry = entry
self.subentry = subentry
# conversation id -> message history self._attr_name = subentry.title
self._attr_name = entry.title self._attr_unique_id = subentry.subentry_id
self._attr_unique_id = entry.entry_id if self.subentry.data.get(CONF_LLM_HASS_API):
if self.entry.options.get(CONF_LLM_HASS_API):
self._attr_supported_features = ( self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL conversation.ConversationEntityFeature.CONTROL
) )
@@ -216,7 +220,7 @@ class OllamaConversationEntity(
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Call the API.""" """Call the API."""
settings = {**self.entry.data, **self.entry.options} settings = {**self.entry.data, **self.subentry.data}
try: try:
await chat_log.async_provide_llm_data( await chat_log.async_provide_llm_data(
@@ -248,7 +252,7 @@ class OllamaConversationEntity(
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> None: ) -> None:
"""Generate an answer for the chat log.""" """Generate an answer for the chat log."""
settings = {**self.entry.data, **self.entry.options} settings = {**self.entry.data, **self.subentry.data}
client = self.hass.data[DOMAIN][self.entry.entry_id] client = self.hass.data[DOMAIN][self.entry.entry_id]
model = settings[CONF_MODEL] model = settings[CONF_MODEL]

View File

@@ -22,23 +22,34 @@
"download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details." "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
} }
}, },
"options": { "config_subentries": {
"step": { "conversation": {
"init": { "initiate_flow": {
"data": { "user": "Add conversation agent",
"prompt": "Instructions", "reconfigure": "Reconfigure conversation agent"
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", },
"max_history": "Max history messages", "entry_type": "Conversation agent",
"num_ctx": "Context window size", "step": {
"keep_alive": "Keep alive", "set_options": {
"think": "Think before responding" "data": {
}, "name": "[%key:common::config_flow::data::name%]",
"data_description": { "prompt": "Instructions",
"prompt": "Instruct how the LLM should respond. This can be a template.", "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.", "max_history": "Max history messages",
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.", "num_ctx": "Context window size",
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency." "keep_alive": "Keep alive",
"think": "Think before responding"
},
"data_description": {
"prompt": "Instruct how the LLM should respond. This can be a template.",
"keep_alive": "Duration in seconds for Ollama to keep model in memory. -1 = indefinite, 0 = never.",
"num_ctx": "Maximum number of text tokens the model can process. Lower to reduce Ollama RAM, or increase for a large number of exposed entities.",
"think": "If enabled, the LLM will think before responding. This can improve response quality but may increase latency."
}
} }
},
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
} }
} }
} }

View File

@@ -30,7 +30,15 @@ def mock_config_entry(
entry = MockConfigEntry( entry = MockConfigEntry(
domain=ollama.DOMAIN, domain=ollama.DOMAIN,
data=TEST_USER_DATA, data=TEST_USER_DATA,
options=mock_config_entry_options, version=2,
subentries_data=[
{
"data": mock_config_entry_options,
"subentry_type": "conversation",
"title": "Ollama Conversation",
"unique_id": None,
}
],
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
return entry return entry
@@ -41,8 +49,10 @@ def mock_config_entry_with_assist(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Mock a config entry with assist.""" """Mock a config entry with assist."""
hass.config_entries.async_update_entry( hass.config_entries.async_update_subentry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} mock_config_entry,
next(iter(mock_config_entry.subentries.values())),
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
) )
return mock_config_entry return mock_config_entry

View File

@@ -155,14 +155,21 @@ async def test_form_need_download(hass: HomeAssistant) -> None:
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
async def test_options( async def test_subentry_options(
hass: HomeAssistant, mock_config_entry, mock_init_component hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None: ) -> None:
"""Test the options form.""" """Test the subentry options form."""
options_flow = await hass.config_entries.options.async_init( subentry = next(iter(mock_config_entry.subentries.values()))
mock_config_entry.entry_id
# Test reconfiguration
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_type, subentry.subentry_id
) )
options = await hass.config_entries.options.async_configure(
assert options_flow["type"] is FlowResultType.FORM
assert options_flow["step_id"] == "set_options"
options = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"], options_flow["flow_id"],
{ {
ollama.CONF_PROMPT: "test prompt", ollama.CONF_PROMPT: "test prompt",
@@ -172,8 +179,10 @@ async def test_options(
}, },
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert options["type"] is FlowResultType.CREATE_ENTRY
assert options["data"] == { assert options["type"] is FlowResultType.ABORT
assert options["reason"] == "reconfigure_successful"
assert subentry.data == {
ollama.CONF_PROMPT: "test prompt", ollama.CONF_PROMPT: "test prompt",
ollama.CONF_MAX_HISTORY: 100, ollama.CONF_MAX_HISTORY: 100,
ollama.CONF_NUM_CTX: 32768, ollama.CONF_NUM_CTX: 32768,

View File

@@ -35,7 +35,7 @@ async def stream_generator(response: dict | list[dict]) -> AsyncGenerator[dict]:
yield msg yield msg
@pytest.mark.parametrize("agent_id", [None, "conversation.mock_title"]) @pytest.mark.parametrize("agent_id", [None, "conversation.ollama_conversation"])
async def test_chat( async def test_chat(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
@@ -149,9 +149,11 @@ async def test_template_variables(
mock_user.id = "12345" mock_user.id = "12345"
mock_user.name = "Test User" mock_user.name = "Test User"
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ subentry,
data={
"prompt": ( "prompt": (
"The user name is {{ user_name }}. " "The user name is {{ user_name }}. "
"The user id is {{ llm_context.context.user_id }}." "The user id is {{ llm_context.context.user_id }}."
@@ -382,10 +384,12 @@ async def test_unknown_hass_api(
mock_init_component, mock_init_component,
) -> None: ) -> None:
"""Test when we reference an API that no longer exists.""" """Test when we reference an API that no longer exists."""
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ subentry,
**mock_config_entry.options, data={
**subentry.data,
CONF_LLM_HASS_API: "non-existing", CONF_LLM_HASS_API: "non-existing",
}, },
) )
@@ -518,8 +522,9 @@ async def test_message_history_unlimited(
with ( with (
patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat, patch("ollama.AsyncClient.chat", side_effect=stream) as mock_chat,
): ):
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
mock_config_entry, options={ollama.CONF_MAX_HISTORY: 0} hass.config_entries.async_update_subentry(
mock_config_entry, subentry, data={ollama.CONF_MAX_HISTORY: 0}
) )
for i in range(100): for i in range(100):
result = await conversation.async_converse( result = await conversation.async_converse(
@@ -563,9 +568,11 @@ async def test_template_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
"""Test that template error handling works.""" """Test that template error handling works."""
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ subentry,
data={
"prompt": "talk like a {% if True %}smarthome{% else %}pirate please.", "prompt": "talk like a {% if True %}smarthome{% else %}pirate please.",
}, },
) )
@@ -593,7 +600,7 @@ async def test_conversation_agent(
) )
assert agent.supported_languages == MATCH_ALL assert agent.supported_languages == MATCH_ALL
state = hass.states.get("conversation.mock_title") state = hass.states.get("conversation.ollama_conversation")
assert state assert state
assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0 assert state.attributes[ATTR_SUPPORTED_FEATURES] == 0
@@ -609,7 +616,7 @@ async def test_conversation_agent_with_assist(
) )
assert agent.supported_languages == MATCH_ALL assert agent.supported_languages == MATCH_ALL
state = hass.states.get("conversation.mock_title") state = hass.states.get("conversation.ollama_conversation")
assert state assert state
assert ( assert (
state.attributes[ATTR_SUPPORTED_FEATURES] state.attributes[ATTR_SUPPORTED_FEATURES]
@@ -642,7 +649,7 @@ async def test_options(
"test message", "test message",
None, None,
Context(), Context(),
agent_id="conversation.mock_title", agent_id="conversation.ollama_conversation",
) )
assert mock_chat.call_count == 1 assert mock_chat.call_count == 1
@@ -667,9 +674,11 @@ async def test_reasoning_filter(
entry = MockConfigEntry() entry = MockConfigEntry()
entry.add_to_hass(hass) entry.add_to_hass(hass)
hass.config_entries.async_update_entry( subentry = next(iter(mock_config_entry.subentries.values()))
hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ subentry,
data={
ollama.CONF_THINK: think, ollama.CONF_THINK: think,
}, },
) )

View File

@@ -6,9 +6,13 @@ from httpx import ConnectError
import pytest import pytest
from homeassistant.components import ollama from homeassistant.components import ollama
from homeassistant.components.ollama.const import DEFAULT_CONVERSATION_NAME, DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import TEST_OPTIONS, TEST_USER_DATA
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@@ -34,3 +38,52 @@ async def test_init_error(
assert await async_setup_component(hass, ollama.DOMAIN, {}) assert await async_setup_component(hass, ollama.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
assert error in caplog.text assert error in caplog.text
async def test_migration_from_v1_to_v2(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2."""
# Create a v1 config entry with conversation options and an entity
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data=TEST_USER_DATA,
options=TEST_OPTIONS,
version=1,
title="llama-3.2-8b",
)
mock_config_entry.add_to_hass(hass)
entity = entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry.entry_id,
config_entry=mock_config_entry,
suggested_object_id="llama_3_2_8b",
)
# Run migration
with patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
assert mock_config_entry.version == 2
assert mock_config_entry.data == TEST_USER_DATA
assert mock_config_entry.options == {}
assert len(mock_config_entry.subentries) == 1
subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.unique_id is None
assert subentry.title == DEFAULT_CONVERSATION_NAME
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
migrated_entity = entity_registry.async_get(entity.entity_id)
assert migrated_entity is not None
assert migrated_entity.config_entry_id == mock_config_entry.entry_id
assert migrated_entity.config_subentry_id == subentry.subentry_id