From 62ac530996235aeacf532aadc178a68146ef71c2 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Tue, 24 Jun 2025 02:32:50 +0000 Subject: [PATCH] Add latest changes from Google subentries --- homeassistant/components/ollama/__init__.py | 86 +++++--- .../components/ollama/config_flow.py | 16 +- homeassistant/components/ollama/const.py | 7 +- .../components/ollama/conversation.py | 11 +- homeassistant/components/ollama/strings.json | 6 +- tests/components/ollama/test_config_flow.py | 49 ++++- tests/components/ollama/test_init.py | 202 +++++++++++++++++- 7 files changed, 337 insertions(+), 40 deletions(-) diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py index 48fcbb68a2e..fcd958f6f91 100644 --- a/homeassistant/components/ollama/__init__.py +++ b/homeassistant/components/ollama/__init__.py @@ -12,7 +12,12 @@ from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_URL, Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady -from homeassistant.helpers import config_validation as cv, entity_registry as er +from homeassistant.helpers import ( + config_validation as cv, + device_registry as dr, + entity_registry as er, +) +from homeassistant.helpers.typing import ConfigType from homeassistant.util.ssl import get_default_context from .const import ( @@ -22,9 +27,9 @@ from .const import ( CONF_NUM_CTX, CONF_PROMPT, CONF_THINK, - DEFAULT_CONVERSATION_NAME, DEFAULT_TIMEOUT, DOMAIN, + OllamaConfigEntry, ) _LOGGER = logging.getLogger(__name__) @@ -44,7 +49,13 @@ CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) PLATFORMS = (Platform.CONVERSATION,) -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up Ollama.""" + await async_migrate_integration(hass) + return True + + +async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool: """Set up Ollama from a config entry.""" settings = {**entry.data, **entry.options} client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context()) @@ -54,8 +65,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: except (TimeoutError, httpx.ConnectError) as err: raise ConfigEntryNotReady(err) from err - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client - + entry.runtime_data = client await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True @@ -64,45 +74,69 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Ollama.""" if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS): return False - hass.data[DOMAIN].pop(entry.entry_id) return True -async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: - """Migrate old entry.""" - if entry.version == 1: - # Migrate from version 1 to version 2 - # Move conversation-specific options to a subentry +async def async_migrate_integration(hass: HomeAssistant) -> None: + """Migrate integration entry structure.""" + + entries = hass.config_entries.async_entries(DOMAIN) + if not any(entry.version == 1 for entry in entries): + return + + api_keys_entries: dict[str, ConfigEntry] = {} + entity_registry = er.async_get(hass) + device_registry = dr.async_get(hass) + + for entry in entries: + use_existing = False subentry = ConfigSubentry( data=entry.options, subentry_type="conversation", - title=DEFAULT_CONVERSATION_NAME, + title=entry.title, unique_id=None, ) - hass.config_entries.async_add_subentry( - entry, - subentry, - ) + if entry.data[CONF_URL] not in api_keys_entries: + use_existing = True + api_keys_entries[entry.data[CONF_URL]] = entry - # Migrate conversation entity to be linked to subentry - ent_reg = er.async_get(hass) - conversation_entity = ent_reg.async_get_entity_id( + parent_entry = api_keys_entries[entry.data[CONF_URL]] + + hass.config_entries.async_add_subentry(parent_entry, subentry) + conversation_entity = entity_registry.async_get_entity_id( "conversation", DOMAIN, entry.entry_id, ) if conversation_entity is not None: - ent_reg.async_update_entity( + entity_registry.async_update_entity( conversation_entity, + config_entry_id=parent_entry.entry_id, config_subentry_id=subentry.subentry_id, new_unique_id=subentry.subentry_id, ) - # Remove options from the main entry - hass.config_entries.async_update_entry( - entry, - options={}, - version=2, + device = device_registry.async_get_device( + identifiers={(DOMAIN, entry.entry_id)} ) + if device is not None: + device_registry.async_update_device( + device.id, + new_identifiers={(DOMAIN, subentry.subentry_id)}, + add_config_subentry_id=subentry.subentry_id, + add_config_entry_id=parent_entry.entry_id, + ) + if parent_entry.entry_id != entry.entry_id: + device_registry.async_update_device( + device.id, + remove_config_entry_id=entry.entry_id, + ) - return True + if not use_existing: + await hass.config_entries.async_remove(entry.entry_id) + else: + hass.config_entries.async_update_entry( + entry, + options={}, + version=2, + ) diff --git a/homeassistant/components/ollama/config_flow.py b/homeassistant/components/ollama/config_flow.py index f4f504f1783..58b557549e1 100644 --- a/homeassistant/components/ollama/config_flow.py +++ b/homeassistant/components/ollama/config_flow.py @@ -14,6 +14,7 @@ import voluptuous as vol from homeassistant.config_entries import ( ConfigEntry, + ConfigEntryState, ConfigFlow, ConfigFlowResult, ConfigSubentryFlow, @@ -44,7 +45,6 @@ from .const import ( CONF_NUM_CTX, CONF_PROMPT, CONF_THINK, - DEFAULT_CONVERSATION_NAME, DEFAULT_KEEP_ALIVE, DEFAULT_MAX_HISTORY, DEFAULT_MODEL, @@ -96,6 +96,8 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): errors = {} + self._async_abort_entries_match({CONF_URL: self.url}) + try: self.client = ollama.AsyncClient( host=self.url, verify=get_default_context() @@ -148,13 +150,13 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): return await self.async_step_download() return self.async_create_entry( - title=_get_title(self.model), + title=self.url, data={CONF_URL: self.url, CONF_MODEL: self.model}, subentries=[ { "subentry_type": "conversation", "data": {}, - "title": DEFAULT_CONVERSATION_NAME, + "title": _get_title(self.model), "unique_id": None, } ], @@ -203,7 +205,7 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): { "subentry_type": "conversation", "data": {}, - "title": DEFAULT_CONVERSATION_NAME, + "title": _get_title(self.model), "unique_id": None, } ], @@ -236,6 +238,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow): self, user_input: dict[str, Any] | None = None ) -> SubentryFlowResult: """Set conversation options.""" + # abort if entry is not loaded + if self._get_entry().state != ConfigEntryState.LOADED: + return self.async_abort(reason="entry_not_loaded") + errors: dict[str, str] = {} if user_input is None: @@ -279,7 +285,7 @@ def ollama_config_option_schema( if is_new: schema: dict[vol.Required | vol.Optional, Any] = { - vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str, + vol.Required(CONF_NAME, default="Ollama Conversation"): str, } else: schema = {} diff --git a/homeassistant/components/ollama/const.py b/homeassistant/components/ollama/const.py index fa6749e983f..d734cc889f2 100644 --- a/homeassistant/components/ollama/const.py +++ b/homeassistant/components/ollama/const.py @@ -1,6 +1,11 @@ """Constants for the Ollama integration.""" +import ollama + +from homeassistant.config_entries import ConfigEntry + DOMAIN = "ollama" +type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient] CONF_MODEL = "model" CONF_PROMPT = "prompt" @@ -157,5 +162,3 @@ MODEL_NAMES = [ # https://ollama.com/library "zephyr", ] DEFAULT_MODEL = "llama3.2:latest" - -DEFAULT_CONVERSATION_NAME = "Ollama Conversation" diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index b1e121352e5..99dc4b1c430 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -15,7 +15,7 @@ from homeassistant.config_entries import ConfigEntry, ConfigSubentry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import intent, llm +from homeassistant.helpers import device_registry as dr, intent, llm from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback from .const import ( @@ -188,6 +188,13 @@ class OllamaConversationEntity( self.subentry = subentry self._attr_name = subentry.title self._attr_unique_id = subentry.subentry_id + self._attr_device_info = dr.DeviceInfo( + identifiers={(DOMAIN, subentry.subentry_id)}, + name=subentry.title, + manufacturer="Ollama", + model=entry.data[CONF_MODEL], + entry_type=dr.DeviceEntryType.SERVICE, + ) if self.subentry.data.get(CONF_LLM_HASS_API): self._attr_supported_features = ( conversation.ConversationEntityFeature.CONTROL @@ -254,7 +261,7 @@ class OllamaConversationEntity( """Generate an answer for the chat log.""" settings = {**self.entry.data, **self.subentry.data} - client = self.hass.data[DOMAIN][self.entry.entry_id] + client = self.entry.runtime_data model = settings[CONF_MODEL] tools: list[dict[str, Any]] | None = None diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json index 71ba8ccd8d2..74a5eaff454 100644 --- a/homeassistant/components/ollama/strings.json +++ b/homeassistant/components/ollama/strings.json @@ -12,7 +12,8 @@ } }, "abort": { - "download_failed": "Model downloading failed" + "download_failed": "Model downloading failed", + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" }, "error": { "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", @@ -49,7 +50,8 @@ } }, "abort": { - "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]" + "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]", + "entry_not_loaded": "Cannot add things while the configuration is disabled." } } } diff --git a/tests/components/ollama/test_config_flow.py b/tests/components/ollama/test_config_flow.py index 99083c58f85..4b78df9bce2 100644 --- a/tests/components/ollama/test_config_flow.py +++ b/tests/components/ollama/test_config_flow.py @@ -63,6 +63,37 @@ async def test_form(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 +async def test_duplicate_entry(hass: HomeAssistant) -> None: + """Test we abort on duplicate config entry.""" + MockConfigEntry( + domain=ollama.DOMAIN, + data={ + ollama.CONF_URL: "http://localhost:11434", + ollama.CONF_MODEL: "test_model", + }, + ).add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + assert not result["errors"] + + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + return_value={"models": [{"model": "test_model"}]}, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + ollama.CONF_URL: "http://localhost:11434", + }, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "already_configured" + + async def test_form_need_download(hass: HomeAssistant) -> None: """Test flow when a model needs to be downloaded.""" # Pretend we already set up a config entry. @@ -163,7 +194,7 @@ async def test_subentry_options( # Test reconfiguration options_flow = await mock_config_entry.start_subentry_reconfigure_flow( - hass, subentry.subentry_type, subentry.subentry_id + hass, subentry.subentry_id ) assert options_flow["type"] is FlowResultType.FORM @@ -190,6 +221,22 @@ async def test_subentry_options( } +async def test_creating_conversation_subentry_not_loaded( + hass: HomeAssistant, + mock_init_component, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating a conversation subentry when entry is not loaded.""" + await hass.config_entries.async_unload(mock_config_entry.entry_id) + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "conversation"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "entry_not_loaded" + + @pytest.mark.parametrize( ("side_effect", "error"), [ diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index bfb7154801a..e11eb98451a 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -6,7 +6,7 @@ from httpx import ConnectError import pytest from homeassistant.components import ollama -from homeassistant.components.ollama.const import DEFAULT_CONVERSATION_NAME, DOMAIN +from homeassistant.components.ollama.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component @@ -56,11 +56,20 @@ async def test_migration_from_v1_to_v2( ) mock_config_entry.add_to_hass(hass) + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) entity = entity_registry.async_get_or_create( "conversation", DOMAIN, mock_config_entry.entry_id, config_entry=mock_config_entry, + device_id=device.id, suggested_object_id="llama_3_2_8b", ) @@ -79,7 +88,7 @@ async def test_migration_from_v1_to_v2( subentry = next(iter(mock_config_entry.subentries.values())) assert subentry.unique_id is None - assert subentry.title == DEFAULT_CONVERSATION_NAME + assert subentry.title == "llama-3.2-8b" assert subentry.subentry_type == "conversation" assert subentry.data == TEST_OPTIONS @@ -87,3 +96,192 @@ async def test_migration_from_v1_to_v2( assert migrated_entity is not None assert migrated_entity.config_entry_id == mock_config_entry.entry_id assert migrated_entity.config_subentry_id == subentry.subentry_id + assert migrated_entity.unique_id == subentry.subentry_id + + # Check device migration + assert not device_registry.async_get_device( + identifiers={(DOMAIN, mock_config_entry.entry_id)} + ) + assert ( + migrated_device := device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + ) + assert migrated_device.identifiers == {(DOMAIN, subentry.subentry_id)} + assert migrated_device.id == device.id + + +async def test_migration_from_v1_to_v2_with_multiple_urls( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2 with different URLs.""" + # Create two v1 config entries with different URLs + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama 1", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11435", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama 1", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="ollama_1", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="Ollama", + model="Ollama 2", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="ollama_2", + ) + + # Run migration + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 2 + + for idx, entry in enumerate(entries): + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 1 + subentry = list(entry.subentries.values())[0] + assert subentry.subentry_type == "conversation" + assert subentry.data == TEST_OPTIONS + assert subentry.title == f"Ollama {idx + 1}" + + dev = device_registry.async_get_device( + identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} + ) + assert dev is not None + + +async def test_migration_from_v1_to_v2_with_same_urls( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + entity_registry: er.EntityRegistry, +) -> None: + """Test migration from version 1 to version 2 with same URLs consolidates entries.""" + # Create two v1 config entries with the same URL + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, + options=TEST_OPTIONS, + version=1, + title="Ollama", + ) + mock_config_entry.add_to_hass(hass) + mock_config_entry_2 = MockConfigEntry( + domain=DOMAIN, + data={"url": "http://localhost:11434", "model": "llama3.2:latest"}, # Same URL + options=TEST_OPTIONS, + version=1, + title="Ollama 2", + ) + mock_config_entry_2.add_to_hass(hass) + + device = device_registry.async_get_or_create( + config_entry_id=mock_config_entry.entry_id, + identifiers={(DOMAIN, mock_config_entry.entry_id)}, + name=mock_config_entry.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry.entry_id, + config_entry=mock_config_entry, + device_id=device.id, + suggested_object_id="ollama", + ) + + device_2 = device_registry.async_get_or_create( + config_entry_id=mock_config_entry_2.entry_id, + identifiers={(DOMAIN, mock_config_entry_2.entry_id)}, + name=mock_config_entry_2.title, + manufacturer="Ollama", + model="Ollama", + entry_type=dr.DeviceEntryType.SERVICE, + ) + entity_registry.async_get_or_create( + "conversation", + DOMAIN, + mock_config_entry_2.entry_id, + config_entry=mock_config_entry_2, + device_id=device_2.id, + suggested_object_id="ollama_2", + ) + + # Run migration + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + # Should have only one entry left (consolidated) + entries = hass.config_entries.async_entries(DOMAIN) + assert len(entries) == 1 + + entry = entries[0] + assert entry.version == 2 + assert not entry.options + assert len(entry.subentries) == 2 # Two subentries from the two original entries + + # Check both subentries exist with correct data + subentries = list(entry.subentries.values()) + titles = [sub.title for sub in subentries] + assert "Ollama" in titles + assert "Ollama 2" in titles + + for subentry in subentries: + assert subentry.subentry_type == "conversation" + assert subentry.data == TEST_OPTIONS + + # Check devices were migrated correctly + dev = device_registry.async_get_device( + identifiers={(DOMAIN, subentry.subentry_id)} + ) + assert dev is not None