Migrate Google Gen AI to use subentries

This commit is contained in:
Paulus Schoutsen
2025-06-21 23:29:28 +00:00
parent c453eed32d
commit 818e86f16e
10 changed files with 352 additions and 114 deletions

View File

@@ -12,7 +12,7 @@ from google.genai.types import File, FileState
from requests.exceptions import Timeout from requests.exceptions import Timeout
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_API_KEY, Platform from homeassistant.const import CONF_API_KEY, Platform
from homeassistant.core import ( from homeassistant.core import (
HomeAssistant, HomeAssistant,
@@ -26,13 +26,14 @@ from homeassistant.exceptions import (
ConfigEntryNotReady, ConfigEntryNotReady,
HomeAssistantError, HomeAssistantError,
) )
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv, entity_registry as er
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
CONF_PROMPT, CONF_PROMPT,
DEFAULT_CONVERSATION_NAME,
DOMAIN, DOMAIN,
FILE_POLLING_INTERVAL_SECONDS, FILE_POLLING_INTERVAL_SECONDS,
LOGGER, LOGGER,
@@ -209,3 +210,42 @@ async def async_unload_entry(
return False return False
return True return True
async def async_migrate_entry(
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> bool:
"""Migrate old entry."""
if entry.version == 1:
# Migrate from version 1 to version 2
# Move conversation-specific options to a subentry
subentry = ConfigSubentry(
data=entry.options,
subentry_type="conversation",
title=DEFAULT_CONVERSATION_NAME,
unique_id=None,
)
hass.config_entries.async_add_subentry(
entry,
subentry,
)
# Migrate conversation entity to be linked to subentry
ent_reg = er.async_get(hass)
for entity_entry in er.async_entries_for_config_entry(ent_reg, entry.entry_id):
if entity_entry.domain == Platform.CONVERSATION:
ent_reg.async_update_entity(
entity_entry.entity_id,
config_subentry_id=subentry.subentry_id,
new_unique_id=subentry.subentry_id,
)
break
# Remove options from the main entry
hass.config_entries.async_update_entry(
entry,
options={},
version=2,
)
return True

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
from types import MappingProxyType
from typing import Any from typing import Any
from google import genai from google import genai
@@ -17,10 +16,11 @@ from homeassistant.config_entries import (
ConfigEntry, ConfigEntry,
ConfigFlow, ConfigFlow,
ConfigFlowResult, ConfigFlowResult,
OptionsFlow, ConfigSubentryFlow,
SubentryFlowResult,
) )
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm from homeassistant.helpers import llm
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
NumberSelector, NumberSelector,
@@ -45,6 +45,7 @@ from .const import (
CONF_TOP_K, CONF_TOP_K,
CONF_TOP_P, CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_CONVERSATION_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD,
@@ -66,7 +67,7 @@ STEP_API_DATA_SCHEMA = vol.Schema(
RECOMMENDED_OPTIONS = { RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST, CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT, CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
} }
@@ -90,7 +91,7 @@ async def validate_input(data: dict[str, Any]) -> None:
class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN): class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Google Generative AI Conversation.""" """Handle a config flow for Google Generative AI Conversation."""
VERSION = 1 VERSION = 2
async def async_step_api( async def async_step_api(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
@@ -117,7 +118,14 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_create_entry( return self.async_create_entry(
title="Google Generative AI", title="Google Generative AI",
data=user_input, data=user_input,
options=RECOMMENDED_OPTIONS, subentries=[
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
],
) )
return self.async_show_form( return self.async_show_form(
step_id="api", step_id="api",
@@ -156,58 +164,90 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
}, },
) )
@staticmethod @classmethod
def async_get_options_flow( @callback
config_entry: ConfigEntry, def async_get_supported_subentry_types(
) -> OptionsFlow: cls, config_entry: ConfigEntry
"""Create the options flow.""" ) -> dict[str, type[ConfigSubentryFlow]]:
return GoogleGenerativeAIOptionsFlow(config_entry) """Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
class GoogleGenerativeAIOptionsFlow(OptionsFlow): class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Google Generative AI config flow options handler.""" """Flow for managing conversation subentries."""
def __init__(self, config_entry: ConfigEntry) -> None: last_rendered_recommended = False
"""Initialize options flow.""" is_new: bool
self.last_rendered_recommended = config_entry.options.get( start_data: dict[str, Any]
CONF_RECOMMENDED, False
)
self._genai_client = config_entry.runtime_data
async def async_step_init( @property
def _genai_client(self) -> genai.Client:
"""Return the Google Generative AI client."""
return self._get_entry().runtime_data
async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> SubentryFlowResult:
"""Manage the options.""" """Add a subentry."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options self.is_new = True
self.start_data = RECOMMENDED_OPTIONS.copy()
return await self.async_step_set_options()
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Handle reconfiguration of a subentry."""
self.is_new = False
self.start_data = self._get_reconfigure_subentry().data.copy()
return await self.async_step_set_options()
async def async_step_set_options(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Set conversation options."""
options = self.start_data
errors: dict[str, str] = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is not None:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended: if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if not user_input.get(CONF_LLM_HASS_API): if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None) user_input.pop(CONF_LLM_HASS_API, None)
# Don't allow to save options that enable the Google Seearch tool with an Assist API
if not ( if not (
user_input.get(CONF_LLM_HASS_API) user_input.get(CONF_LLM_HASS_API)
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
): ):
# Don't allow to save options that enable the Google Seearch tool with an Assist API if self.is_new:
return self.async_create_entry(title="", data=user_input) return self.async_create_entry(
title=user_input.pop(CONF_NAME),
data=user_input,
)
return self.async_update_and_abort(
self._get_entry(),
self._get_reconfigure_subentry(),
data=user_input,
)
errors[CONF_USE_GOOGLE_SEARCH_TOOL] = "invalid_google_search_option" errors[CONF_USE_GOOGLE_SEARCH_TOOL] = "invalid_google_search_option"
# Re-render the options again, now with the recommended options shown/hidden # Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED] self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = user_input options = user_input
else:
self.last_rendered_recommended = options.get(CONF_RECOMMENDED, False)
schema = await google_generative_ai_config_option_schema( schema = await google_generative_ai_config_option_schema(
self.hass, options, self._genai_client self.hass, self.is_new, options, self._genai_client
) )
return self.async_show_form( return self.async_show_form(
step_id="init", data_schema=vol.Schema(schema), errors=errors step_id="set_options", data_schema=vol.Schema(schema), errors=errors
) )
async def google_generative_ai_config_option_schema( async def google_generative_ai_config_option_schema(
hass: HomeAssistant, hass: HomeAssistant,
is_new: bool,
options: Mapping[str, Any], options: Mapping[str, Any],
genai_client: genai.Client, genai_client: genai.Client,
) -> dict: ) -> dict:
@@ -224,7 +264,15 @@ async def google_generative_ai_config_option_schema(
): ):
suggested_llm_apis = [suggested_llm_apis] suggested_llm_apis = [suggested_llm_apis]
schema = { if is_new:
schema: dict[vol.Required | vol.Optional, Any] = {
vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str,
}
else:
schema = {}
schema.update(
{
vol.Optional( vol.Optional(
CONF_PROMPT, CONF_PROMPT,
description={ description={
@@ -241,6 +289,7 @@ async def google_generative_ai_config_option_schema(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool, ): bool,
} }
)
if options.get(CONF_RECOMMENDED): if options.get(CONF_RECOMMENDED):
return schema return schema

View File

@@ -6,6 +6,8 @@ DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__) LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_CONVERSATION_NAME = "Google Conversation"
ATTR_MODEL = "model" ATTR_MODEL = "model"
CONF_RECOMMENDED = "recommended" CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model" CONF_CHAT_MODEL = "chat_model"

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
from typing import Literal from typing import Literal
from homeassistant.components import assist_pipeline, conversation from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -22,8 +22,14 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback, async_add_entities: AddConfigEntryEntitiesCallback,
) -> None: ) -> None:
"""Set up conversation entities.""" """Set up conversation entities."""
agent = GoogleGenerativeAIConversationEntity(config_entry) for subentry in config_entry.subentries.values():
async_add_entities([agent]) if subentry.subentry_type != "conversation":
continue
async_add_entities(
[GoogleGenerativeAIConversationEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class GoogleGenerativeAIConversationEntity( class GoogleGenerativeAIConversationEntity(
@@ -35,10 +41,10 @@ class GoogleGenerativeAIConversationEntity(
_attr_supports_streaming = True _attr_supports_streaming = True
def __init__(self, entry: ConfigEntry) -> None: def __init__(self, entry: ConfigEntry, subentry: ConfigSubentry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
super().__init__(entry) super().__init__(entry, subentry)
if self.entry.options.get(CONF_LLM_HASS_API): if self.subentry.data.get(CONF_LLM_HASS_API):
self._attr_supported_features = ( self._attr_supported_features = (
conversation.ConversationEntityFeature.CONTROL conversation.ConversationEntityFeature.CONTROL
) )
@@ -70,7 +76,7 @@ class GoogleGenerativeAIConversationEntity(
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> conversation.ConversationResult: ) -> conversation.ConversationResult:
"""Call the API.""" """Call the API."""
options = self.entry.options options = self.subentry.data
try: try:
await chat_log.async_provide_llm_data( await chat_log.async_provide_llm_data(

View File

@@ -24,7 +24,7 @@ from google.genai.types import (
from voluptuous_openapi import convert from voluptuous_openapi import convert
from homeassistant.components import conversation from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
@@ -40,6 +40,7 @@ from .const import (
CONF_TOP_K, CONF_TOP_K,
CONF_TOP_P, CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_CONVERSATION_NAME,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
@@ -301,14 +302,13 @@ async def _transform_stream(
class GoogleGenerativeAILLMBaseEntity(Entity): class GoogleGenerativeAILLMBaseEntity(Entity):
"""Google Generative AI base entity.""" """Google Generative AI base entity."""
_attr_has_entity_name = True def __init__(self, entry: ConfigEntry, subentry: ConfigSubentry) -> None:
_attr_name = None
def __init__(self, entry: ConfigEntry) -> None:
"""Initialize the agent.""" """Initialize the agent."""
self.entry = entry self.entry = entry
self.subentry = subentry
self._attr_name = subentry.title or DEFAULT_CONVERSATION_NAME
self._genai_client = entry.runtime_data self._genai_client = entry.runtime_data
self._attr_unique_id = entry.entry_id self._attr_unique_id = subentry.subentry_id
self._attr_device_info = dr.DeviceInfo( self._attr_device_info = dr.DeviceInfo(
identifiers={(DOMAIN, entry.entry_id)}, identifiers={(DOMAIN, entry.entry_id)},
name=entry.title, name=entry.title,
@@ -322,7 +322,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
chat_log: conversation.ChatLog, chat_log: conversation.ChatLog,
) -> None: ) -> None:
"""Generate an answer for the chat log.""" """Generate an answer for the chat log."""
options = self.entry.options options = self.subentry.data
tools: list[Tool | Callable[..., Any]] | None = None tools: list[Tool | Callable[..., Any]] | None = None
if chat_log.llm_api: if chat_log.llm_api:

View File

@@ -21,10 +21,18 @@
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
} }
}, },
"options": { "config_subentries": {
"conversation": {
"initiate_flow": {
"user": "Add conversation agent",
"reconfigure": "Reconfigure conversation agent"
},
"entry_type": "Conversation agent",
"step": { "step": {
"init": { "set_options": {
"data": { "data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "Recommended model settings", "recommended": "Recommended model settings",
"prompt": "Instructions", "prompt": "Instructions",
"chat_model": "[%key:common::generic::model%]", "chat_model": "[%key:common::generic::model%]",
@@ -48,6 +56,7 @@
"error": { "error": {
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting." "invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
} }
}
}, },
"services": { "services": {
"generate_content": { "generate_content": {

View File

@@ -5,8 +5,9 @@ from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from homeassistant.components.google_generative_ai_conversation.entity import ( from homeassistant.components.google_generative_ai_conversation.const import (
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_CONVERSATION_NAME,
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API
@@ -26,6 +27,15 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
data={ data={
"api_key": "bla", "api_key": "bla",
}, },
version=2,
subentries_data=[
{
"data": {},
"subentry_type": "conversation",
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
],
) )
entry.runtime_data = Mock() entry.runtime_data = Mock()
entry.add_to_hass(hass) entry.add_to_hass(hass)
@@ -38,8 +48,10 @@ async def mock_config_entry_with_assist(
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Mock a config entry with assist.""" """Mock a config entry with assist."""
with patch("google.genai.models.AsyncModels.get"): with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry( hass.config_entries.async_update_subentry(
mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} mock_config_entry,
next(iter(mock_config_entry.subentries.values())),
data={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
) )
await hass.async_block_till_done() await hass.async_block_till_done()
return mock_config_entry return mock_config_entry
@@ -51,9 +63,10 @@ async def mock_config_entry_with_google_search(
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Mock a config entry with assist.""" """Mock a config entry with assist."""
with patch("google.genai.models.AsyncModels.get"): with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry( hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={ next(iter(mock_config_entry.subentries.values())),
data={
CONF_LLM_HASS_API: llm.LLM_API_ASSIST, CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_USE_GOOGLE_SEARCH_TOOL: True, CONF_USE_GOOGLE_SEARCH_TOOL: True,
}, },

View File

@@ -30,7 +30,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
RECOMMENDED_TOP_P, RECOMMENDED_TOP_P,
RECOMMENDED_USE_GOOGLE_SEARCH_TOOL, RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
) )
from homeassistant.const import CONF_LLM_HASS_API from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
@@ -110,10 +110,58 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["data"] == { assert result2["data"] == {
"api_key": "bla", "api_key": "bla",
} }
assert result2["options"] == RECOMMENDED_OPTIONS assert result2["options"] == {}
assert result2["subentries"] == [
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"title": "Google Conversation",
"unique_id": None,
}
]
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
async def test_creating_conversation_subentry(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a conversation subentry."""
mock_config_entry.add_to_hass(hass)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "set_options"
assert not result["errors"]
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock name", **RECOMMENDED_OPTIONS},
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock name"
processed_options = RECOMMENDED_OPTIONS.copy()
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
assert result2["data"] == processed_options
def will_options_be_rendered_again(current_options, new_options) -> bool: def will_options_be_rendered_again(current_options, new_options) -> bool:
"""Determine if options will be rendered again.""" """Determine if options will be rendered again."""
return current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED) return current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED)
@@ -283,7 +331,7 @@ def will_options_be_rendered_again(current_options, new_options) -> bool:
], ],
) )
@pytest.mark.usefixtures("mock_init_component") @pytest.mark.usefixtures("mock_init_component")
async def test_options_switching( async def test_subentry_options_switching(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,
current_options, current_options,
@@ -292,17 +340,18 @@ async def test_options_switching(
errors, errors,
) -> None: ) -> None:
"""Test the options form.""" """Test the options form."""
subentry = next(iter(mock_config_entry.subentries.values()))
with patch("google.genai.models.AsyncModels.get"): with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry( hass.config_entries.async_update_subentry(
mock_config_entry, options=current_options mock_config_entry, subentry, data=current_options
) )
await hass.async_block_till_done() await hass.async_block_till_done()
with patch( with patch(
"google.genai.models.AsyncModels.list", "google.genai.models.AsyncModels.list",
return_value=get_models_pager(), return_value=get_models_pager(),
): ):
options_flow = await hass.config_entries.options.async_init( options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
mock_config_entry.entry_id hass, subentry.subentry_type, subentry.subentry_id
) )
if will_options_be_rendered_again(current_options, new_options): if will_options_be_rendered_again(current_options, new_options):
retry_options = { retry_options = {
@@ -313,7 +362,7 @@ async def test_options_switching(
"google.genai.models.AsyncModels.list", "google.genai.models.AsyncModels.list",
return_value=get_models_pager(), return_value=get_models_pager(),
): ):
options_flow = await hass.config_entries.options.async_configure( options_flow = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"], options_flow["flow_id"],
retry_options, retry_options,
) )
@@ -321,14 +370,15 @@ async def test_options_switching(
"google.genai.models.AsyncModels.list", "google.genai.models.AsyncModels.list",
return_value=get_models_pager(), return_value=get_models_pager(),
): ):
options = await hass.config_entries.options.async_configure( options = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"], options_flow["flow_id"],
new_options, new_options,
) )
await hass.async_block_till_done() await hass.async_block_till_done()
if errors is None: if errors is None:
assert options["type"] is FlowResultType.CREATE_ENTRY assert options["type"] is FlowResultType.ABORT
assert options["data"] == expected_options assert options["reason"] == "reconfigure_successful"
assert subentry.data == expected_options
else: else:
assert options["type"] is FlowResultType.FORM assert options["type"] is FlowResultType.FORM

View File

@@ -64,7 +64,7 @@ async def test_error_handling(
"hello", "hello",
None, None,
Context(), Context(),
agent_id="conversation.google_generative_ai_conversation", agent_id="conversation.google_conversation",
) )
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result assert result.response.error_code == "unknown", result
@@ -82,7 +82,7 @@ async def test_function_call(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Test function calling.""" """Test function calling."""
agent_id = "conversation.google_generative_ai_conversation" agent_id = "conversation.google_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -212,7 +212,7 @@ async def test_google_search_tool_is_sent(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Test if the Google Search tool is sent to the model.""" """Test if the Google Search tool is sent to the model."""
agent_id = "conversation.google_generative_ai_conversation" agent_id = "conversation.google_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -278,7 +278,7 @@ async def test_blocked_response(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Test blocked response.""" """Test blocked response."""
agent_id = "conversation.google_generative_ai_conversation" agent_id = "conversation.google_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -328,7 +328,7 @@ async def test_empty_response(
) -> None: ) -> None:
"""Test empty response.""" """Test empty response."""
agent_id = "conversation.google_generative_ai_conversation" agent_id = "conversation.google_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -371,7 +371,7 @@ async def test_none_response(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Test None response.""" """Test None response."""
agent_id = "conversation.google_generative_ai_conversation" agent_id = "conversation.google_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -403,10 +403,12 @@ async def test_converse_error(
hass: HomeAssistant, mock_config_entry: MockConfigEntry hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None: ) -> None:
"""Test handling ChatLog raising ConverseError.""" """Test handling ChatLog raising ConverseError."""
subentry = next(iter(mock_config_entry.subentries.values()))
with patch("google.genai.models.AsyncModels.get"): with patch("google.genai.models.AsyncModels.get"):
hass.config_entries.async_update_entry( hass.config_entries.async_update_subentry(
mock_config_entry, mock_config_entry,
options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"}, next(iter(mock_config_entry.subentries.values())),
data={**subentry.data, CONF_LLM_HASS_API: "invalid_llm_api"},
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@@ -415,7 +417,7 @@ async def test_converse_error(
"hello", "hello",
None, None,
Context(), Context(),
agent_id="conversation.google_generative_ai_conversation", agent_id="conversation.google_conversation",
) )
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
@@ -593,7 +595,7 @@ async def test_empty_content_in_chat_history(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead.""" """Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
agent_id = "conversation.google_generative_ai_conversation" agent_id = "conversation.google_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -648,7 +650,7 @@ async def test_history_always_user_first_turn(
) -> None: ) -> None:
"""Test that the user is always first in the chat history.""" """Test that the user is always first in the chat history."""
agent_id = "conversation.google_generative_ai_conversation" agent_id = "conversation.google_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -674,7 +676,7 @@ async def test_history_always_user_first_turn(
mock_chat_log.async_add_assistant_content_without_tools( mock_chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent( conversation.AssistantContent(
agent_id="conversation.google_generative_ai_conversation", agent_id="conversation.google_conversation",
content="Garage door left open, do you want to close it?", content="Garage door left open, do you want to close it?",
) )
) )

View File

@@ -7,9 +7,14 @@ import pytest
from requests.exceptions import Timeout from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_generative_ai_conversation import (
async_migrate_entry,
)
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, entity_registry as er
from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID from . import API_ERROR_500, CLIENT_ERROR_API_KEY_INVALID
@@ -387,3 +392,65 @@ async def test_load_entry_with_unloaded_entries(
"text": stubbed_generated_content, "text": stubbed_generated_content,
} }
assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot assert [tuple(mock_call) for mock_call in mock_generate.mock_calls] == snapshot
async def test_migration_from_v1_to_v2(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test migration from version 1 to version 2."""
# Create a v1 config entry with conversation options and an entity
OPTIONS = {
"recommended": True,
"llm_hass_api": ["assist"],
"prompt": "You are a helpful assistant",
"chat_model": "models/gemini-2.0-flash",
}
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={"api_key": "1234"},
options=OPTIONS,
version=1,
title="Google Generative AI",
)
mock_config_entry.add_to_hass(hass)
device = device_registry.async_get_or_create(
config_entry_id=mock_config_entry.entry_id,
identifiers={(DOMAIN, mock_config_entry.entry_id)},
name=mock_config_entry.title,
manufacturer="Google",
model="Generative AI",
entry_type=dr.DeviceEntryType.SERVICE,
)
entity = entity_registry.async_get_or_create(
"conversation",
DOMAIN,
"mock_config_entry.entry_id",
config_entry=mock_config_entry,
device_id=device.id,
suggested_object_id="google_generative_ai_conversation",
)
# Run migration
result = await async_migrate_entry(hass, mock_config_entry)
assert result is True
assert mock_config_entry.version == 2
assert mock_config_entry.data == {"api_key": "1234"}
assert mock_config_entry.options == {}
assert len(mock_config_entry.subentries) == 1
subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.unique_id is None
assert subentry.title == "Google Conversation"
assert subentry.subentry_type == "conversation"
assert subentry.data == OPTIONS
migrated_entity = entity_registry.async_get(entity.entity_id)
assert migrated_entity is not None
assert migrated_entity.config_entry_id == mock_config_entry.entry_id
assert migrated_entity.config_subentry_id == subentry.subentry_id
assert migrated_entity.device_id == device.id