Add latest changes from Google subentries

This commit is contained in:
Paulus Schoutsen
2025-06-24 02:32:50 +00:00
parent 0c1c865ab3
commit 62ac530996
7 changed files with 337 additions and 40 deletions

View File

@ -12,7 +12,12 @@ from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_URL, Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util.ssl import get_default_context
from .const import (
@ -22,9 +27,9 @@ from .const import (
CONF_NUM_CTX,
CONF_PROMPT,
CONF_THINK,
DEFAULT_CONVERSATION_NAME,
DEFAULT_TIMEOUT,
DOMAIN,
OllamaConfigEntry,
)
_LOGGER = logging.getLogger(__name__)
@ -44,7 +49,13 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Ollama."""
await async_migrate_integration(hass)
return True
async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool:
"""Set up Ollama from a config entry."""
settings = {**entry.data, **entry.options}
client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context())
@ -54,8 +65,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
except (TimeoutError, httpx.ConnectError) as err:
raise ConfigEntryNotReady(err) from err
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
entry.runtime_data = client
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True
@ -64,45 +74,69 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Ollama."""
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
return False
hass.data[DOMAIN].pop(entry.entry_id)
return True
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Migrate old entry."""
if entry.version == 1:
# Migrate from version 1 to version 2
# Move conversation-specific options to a subentry
async def async_migrate_integration(hass: HomeAssistant) -> None:
"""Migrate integration entry structure."""
entries = hass.config_entries.async_entries(DOMAIN)
if not any(entry.version == 1 for entry in entries):
return
api_keys_entries: dict[str, ConfigEntry] = {}
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
for entry in entries:
use_existing = False
subentry = ConfigSubentry(
data=entry.options,
subentry_type="conversation",
title=DEFAULT_CONVERSATION_NAME,
title=entry.title,
unique_id=None,
)
hass.config_entries.async_add_subentry(
entry,
subentry,
)
if entry.data[CONF_URL] not in api_keys_entries:
use_existing = True
api_keys_entries[entry.data[CONF_URL]] = entry
# Migrate conversation entity to be linked to subentry
ent_reg = er.async_get(hass)
conversation_entity = ent_reg.async_get_entity_id(
parent_entry = api_keys_entries[entry.data[CONF_URL]]
hass.config_entries.async_add_subentry(parent_entry, subentry)
conversation_entity = entity_registry.async_get_entity_id(
"conversation",
DOMAIN,
entry.entry_id,
)
if conversation_entity is not None:
ent_reg.async_update_entity(
entity_registry.async_update_entity(
conversation_entity,
config_entry_id=parent_entry.entry_id,
config_subentry_id=subentry.subentry_id,
new_unique_id=subentry.subentry_id,
)
# Remove options from the main entry
hass.config_entries.async_update_entry(
entry,
options={},
version=2,
device = device_registry.async_get_device(
identifiers={(DOMAIN, entry.entry_id)}
)
if device is not None:
device_registry.async_update_device(
device.id,
new_identifiers={(DOMAIN, subentry.subentry_id)},
add_config_subentry_id=subentry.subentry_id,
add_config_entry_id=parent_entry.entry_id,
)
if parent_entry.entry_id != entry.entry_id:
device_registry.async_update_device(
device.id,
remove_config_entry_id=entry.entry_id,
)
return True
if not use_existing:
await hass.config_entries.async_remove(entry.entry_id)
else:
hass.config_entries.async_update_entry(
entry,
options={},
version=2,
)

View File

@ -14,6 +14,7 @@ import voluptuous as vol
from homeassistant.config_entries import (
ConfigEntry,
ConfigEntryState,
ConfigFlow,
ConfigFlowResult,
ConfigSubentryFlow,
@ -44,7 +45,6 @@ from .const import (
CONF_NUM_CTX,
CONF_PROMPT,
CONF_THINK,
DEFAULT_CONVERSATION_NAME,
DEFAULT_KEEP_ALIVE,
DEFAULT_MAX_HISTORY,
DEFAULT_MODEL,
@ -96,6 +96,8 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
errors = {}
self._async_abort_entries_match({CONF_URL: self.url})
try:
self.client = ollama.AsyncClient(
host=self.url, verify=get_default_context()
@ -148,13 +150,13 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
return await self.async_step_download()
return self.async_create_entry(
title=_get_title(self.model),
title=self.url,
data={CONF_URL: self.url, CONF_MODEL: self.model},
subentries=[
{
"subentry_type": "conversation",
"data": {},
"title": DEFAULT_CONVERSATION_NAME,
"title": _get_title(self.model),
"unique_id": None,
}
],
@ -203,7 +205,7 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
{
"subentry_type": "conversation",
"data": {},
"title": DEFAULT_CONVERSATION_NAME,
"title": _get_title(self.model),
"unique_id": None,
}
],
@ -236,6 +238,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Set conversation options."""
# abort if entry is not loaded
if self._get_entry().state != ConfigEntryState.LOADED:
return self.async_abort(reason="entry_not_loaded")
errors: dict[str, str] = {}
if user_input is None:
@ -279,7 +285,7 @@ def ollama_config_option_schema(
if is_new:
schema: dict[vol.Required | vol.Optional, Any] = {
vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str,
vol.Required(CONF_NAME, default="Ollama Conversation"): str,
}
else:
schema = {}

View File

@ -1,6 +1,11 @@
"""Constants for the Ollama integration."""
import ollama
from homeassistant.config_entries import ConfigEntry
DOMAIN = "ollama"
type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient]
CONF_MODEL = "model"
CONF_PROMPT = "prompt"
@ -157,5 +162,3 @@ MODEL_NAMES = [ # https://ollama.com/library
"zephyr",
]
DEFAULT_MODEL = "llama3.2:latest"
DEFAULT_CONVERSATION_NAME = "Ollama Conversation"

View File

@ -15,7 +15,7 @@ from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import intent, llm
from homeassistant.helpers import device_registry as dr, intent, llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import (
@ -188,6 +188,13 @@ class OllamaConversationEntity(
self.subentry = subentry
self._attr_name = subentry.title
self._attr_unique_id = subentry.subentry_id
self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, subentry.subentry_id)},
name=subentry.title,
manufacturer="Ollama",
model=entry.data[CONF_MODEL],
entry_type=dr.DeviceEntryType.SERVICE,
)
if self.subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL
@ -254,7 +261,7 @@ class OllamaConversationEntity(
"""Generate an answer for the chat log."""
settings = {**self.entry.data, **self.subentry.data}
client = self.hass.data[DOMAIN][self.entry.entry_id]
client = self.entry.runtime_data
model = settings[CONF_MODEL]
tools: list[dict[str, Any]] | None = None

View File

@ -12,7 +12,8 @@
}
},
"abort": {
"download_failed": "Model downloading failed"
"download_failed": "Model downloading failed",
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
},
"error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
@ -49,7 +50,8 @@
}
},
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
"entry_not_loaded": "Cannot add things while the configuration is disabled."
}
}
}

View File

@ -63,6 +63,37 @@ async def test_form(hass: HomeAssistant) -> None:
assert len(mock_setup_entry.mock_calls) == 1
async def test_duplicate_entry(hass: HomeAssistant) -> None:
"""Test we abort on duplicate config entry."""
MockConfigEntry(
domain=ollama.DOMAIN,
data={
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model",
},
).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert not result["errors"]
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
return_value={"models": [{"model": "test_model"}]},
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
ollama.CONF_URL: "http://localhost:11434",
},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "already_configured"
async def test_form_need_download(hass: HomeAssistant) -> None:
"""Test flow when a model needs to be downloaded."""
# Pretend we already set up a config entry.
@ -163,7 +194,7 @@ async def test_subentry_options(
# Test reconfiguration
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_type, subentry.subentry_id
hass, subentry.subentry_id
)
assert options_flow["type"] is FlowResultType.FORM
@ -190,6 +221,22 @@ async def test_subentry_options(
}
async def test_creating_conversation_subentry_not_loaded(
hass: HomeAssistant,
mock_init_component,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a conversation subentry when entry is not loaded."""
await hass.config_entries.async_unload(mock_config_entry.entry_id)
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "entry_not_loaded"
@pytest.mark.parametrize(
("side_effect", "error"),
[

View File

@ -6,7 +6,7 @@ from httpx import ConnectError
import pytest
from homeassistant.components import ollama
from homeassistant.components.ollama.const import DEFAULT_CONVERSATION_NAME, DOMAIN
from homeassistant.components.ollama.const import DOMAIN
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.setup import async_setup_component
@ -56,11 +56,20 @@ async def test_migration_from_v1_to_v2(
)
mock_config_entry.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, mock_config_entry.entry_id)},
name=mock_config_entry.title,
manufacturer="Ollama",
model="Ollama",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity = entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry.entry_id,
config_entry=mock_config_entry,
device_id=device.id,
suggested_object_id="llama_3_2_8b",
)
@ -79,7 +88,7 @@ async def test_migration_from_v1_to_v2(
subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.unique_id is None
assert subentry.title == DEFAULT_CONVERSATION_NAME
assert subentry.title == "llama-3.2-8b"
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
@ -87,3 +96,192 @@ async def test_migration_from_v1_to_v2(
assert migrated_entity is not None
assert migrated_entity.config_entry_id == mock_config_entry.entry_id
assert migrated_entity.config_subentry_id == subentry.subentry_id
assert migrated_entity.unique_id == subentry.subentry_id
# Check device migration
assert not device_registry.async_get_device(
identifiers={(DOMAIN, mock_config_entry.entry_id)}
)
assert (
migrated_device := device_registry.async_get_device(
identifiers={(DOMAIN, subentry.subentry_id)}
)
)
assert migrated_device.identifiers == {(DOMAIN, subentry.subentry_id)}
assert migrated_device.id == device.id
async def test_migration_from_v1_to_v2_with_multiple_urls(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2 with different URLs."""
# Create two v1 config entries with different URLs
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
version=1,
title="Ollama 1",
)
mock_config_entry.add_to_hass(hass)
mock_config_entry_2 = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11435", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
version=1,
title="Ollama 2",
)
mock_config_entry_2.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, mock_config_entry.entry_id)},
name=mock_config_entry.title,
manufacturer="Ollama",
model="Ollama 1",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry.entry_id,
config_entry=mock_config_entry,
device_id=device.id,
suggested_object_id="ollama_1",
)
device_2 = device_registry.async_get_or_create(
config_entry_id=mock_config_entry_2.entry_id,
identifiers={(DOMAIN, mock_config_entry_2.entry_id)},
name=mock_config_entry_2.title,
manufacturer="Ollama",
model="Ollama 2",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry_2.entry_id,
config_entry=mock_config_entry_2,
device_id=device_2.id,
suggested_object_id="ollama_2",
)
# Run migration
with patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 2
for idx, entry in enumerate(entries):
assert entry.version == 2
assert not entry.options
assert len(entry.subentries) == 1
subentry = list(entry.subentries.values())[0]
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
assert subentry.title == f"Ollama {idx + 1}"
dev = device_registry.async_get_device(
identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)}
)
assert dev is not None
async def test_migration_from_v1_to_v2_with_same_urls(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2 with same URLs consolidates entries."""
# Create two v1 config entries with the same URL
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"},
options=TEST_OPTIONS,
version=1,
title="Ollama",
)
mock_config_entry.add_to_hass(hass)
mock_config_entry_2 = MockConfigEntry(
domain=DOMAIN,
data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL
options=TEST_OPTIONS,
version=1,
title="Ollama 2",
)
mock_config_entry_2.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, mock_config_entry.entry_id)},
name=mock_config_entry.title,
manufacturer="Ollama",
model="Ollama",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry.entry_id,
config_entry=mock_config_entry,
device_id=device.id,
suggested_object_id="ollama",
)
device_2 = device_registry.async_get_or_create(
config_entry_id=mock_config_entry_2.entry_id,
identifiers={(DOMAIN, mock_config_entry_2.entry_id)},
name=mock_config_entry_2.title,
manufacturer="Ollama",
model="Ollama",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity_registry.async_get_or_create(
"conversation",
DOMAIN,
mock_config_entry_2.entry_id,
config_entry=mock_config_entry_2,
device_id=device_2.id,
suggested_object_id="ollama_2",
)
# Run migration
with patch(
"homeassistant.components.ollama.async_setup_entry",
return_value=True,
):
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
# Should have only one entry left (consolidated)
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.version == 2
assert not entry.options
assert len(entry.subentries) == 2 # Two subentries from the two original entries
# Check both subentries exist with correct data
subentries = list(entry.subentries.values())
titles = [sub.title for sub in subentries]
assert "Ollama" in titles
assert "Ollama 2" in titles
for subentry in subentries:
assert subentry.subentry_type == "conversation"
assert subentry.data == TEST_OPTIONS
# Check devices were migrated correctly
dev = device_registry.async_get_device(
identifiers={(DOMAIN, subentry.subentry_id)}
)
assert dev is not None