Add Google AI Task platform
@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import mimetypes
from pathlib import Path
from types import MappingProxyType

from google.genai import Client
from google.genai.errors import APIError, ClientError
@@ -37,9 +38,11 @@ from homeassistant.helpers.typing import ConfigType
from .const import (
    CONF_CHAT_MODEL,
    CONF_PROMPT,
    DEFAULT_AI_TASK_NAME,
    DOMAIN,
    FILE_POLLING_INTERVAL_SECONDS,
    LOGGER,
    RECOMMENDED_AI_TASK_OPTIONS,
    RECOMMENDED_CHAT_MODEL,
    TIMEOUT_MILLIS,
)
@@ -50,6 +53,7 @@ CONF_FILENAMES = "filenames"

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (
    Platform.AI_TASK,
    Platform.CONVERSATION,
    Platform.TTS,
)
@@ -280,3 +284,25 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
            options={},
            version=2,
        )


async def async_migrate_entry(
    hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> bool:
    """Migrate old entry."""
    if entry.version == 2 and entry.minor_version == 0:
        # Migrate from version 2.0 to 2.1
        # Add AI Task subentry with default options
        hass.config_entries.async_add_subentry(
            entry,
            ConfigSubentry(
                data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
                subentry_type="ai_task",
                title=DEFAULT_AI_TASK_NAME,
                unique_id=None,
            ),
        )
        hass.config_entries.async_update_entry(entry, minor_version=1)
        return True

    return False
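For context on how the migration above gets triggered: Home Assistant only calls `async_migrate_entry` during entry setup when a stored entry's version or minor version is behind what the integration's config flow declares, so the 2.0 to 2.1 branch relies on the flow advertising the newer minor version. A minimal sketch of that declaration, with `MINOR_VERSION = 1` treated as an assumed value (it is not shown in the hunks above):

```python
from homeassistant.config_entries import ConfigFlow

from homeassistant.components.google_generative_ai_conversation.const import DOMAIN


class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
    """Sketch of the flow's version declaration (values assumed, not from this diff)."""

    # Entries stored with an older (version, minor_version) pair are routed
    # through async_migrate_entry when the entry is set up.
    VERSION = 2
    MINOR_VERSION = 1  # assumed: lets 2.0 entries reach the new migration branch
```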
@@ -0,0 +1,57 @@
"""AI Task integration for Google Generative AI Conversation."""

from __future__ import annotations

from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback

from .const import LOGGER
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up AI Task entities."""
    for subentry in config_entry.subentries.values():
        if subentry.subentry_type != "ai_task":
            continue

        async_add_entities(
            [GoogleGenerativeAITaskEntity(config_entry, subentry)],
            config_subentry_id=subentry.subentry_id,
        )


class GoogleGenerativeAITaskEntity(
    ai_task.AITaskEntity,
    GoogleGenerativeAILLMBaseEntity,
):
    """Google Generative AI AI Task entity."""

    _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_TEXT

    async def _async_generate_text(
        self,
        task: ai_task.GenTextTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenTextTaskResult:
        """Handle a generate text task."""
        await self._async_handle_chat_log(chat_log)

        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
            LOGGER.error(
                "Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
                chat_log.content[-1],
            )
            raise HomeAssistantError(ERROR_GETTING_RESPONSE)

        return ai_task.GenTextTaskResult(
            conversation_id=chat_log.conversation_id,
            text=chat_log.content[-1].content or "",
        )
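To see how the new entity is exercised, the `ai_task.async_generate_text` helper (used by the test added later in this commit) drives `_async_generate_text` through the conversation chat log. A minimal usage sketch, assuming a running `hass` with this integration set up and the default entity name:

```python
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant


async def generate_with_google_ai(hass: HomeAssistant) -> str:
    """Sketch: ask the Google AI Task entity for a piece of generated text."""
    result = await ai_task.async_generate_text(
        hass,
        task_name="Example task",
        entity_id="ai_task.google_ai_task",  # entity created from the "Google AI Task" subentry
        instructions="Summarize the day's calendar in one sentence.",
    )
    # GenTextTaskResult exposes the generated text and the conversation id.
    return result.text
```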
@@ -46,9 +46,12 @@ from .const import (
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_USE_GOOGLE_SEARCH_TOOL,
    DEFAULT_AI_TASK_NAME,
    DEFAULT_CONVERSATION_NAME,
    DOMAIN,
    RECOMMENDED_AI_TASK_OPTIONS,
    RECOMMENDED_CHAT_MODEL,
    RECOMMENDED_CONVERSATION_OPTIONS,
    RECOMMENDED_HARM_BLOCK_THRESHOLD,
    RECOMMENDED_MAX_TOKENS,
    RECOMMENDED_TEMPERATURE,
@@ -66,12 +69,6 @@ STEP_API_DATA_SCHEMA = vol.Schema(
    }
)

RECOMMENDED_OPTIONS = {
    CONF_RECOMMENDED: True,
    CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
    CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}


async def validate_input(data: dict[str, Any]) -> None:
    """Validate the user input allows us to connect.
@@ -123,10 +120,16 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
                    subentries=[
                        {
                            "subentry_type": "conversation",
                            "data": RECOMMENDED_OPTIONS,
                            "data": RECOMMENDED_CONVERSATION_OPTIONS,
                            "title": DEFAULT_CONVERSATION_NAME,
                            "unique_id": None,
                        }
                        },
                        {
                            "subentry_type": "ai_task",
                            "data": RECOMMENDED_AI_TASK_OPTIONS,
                            "title": DEFAULT_AI_TASK_NAME,
                            "unique_id": None,
                        },
                    ],
                )
        return self.async_show_form(
@@ -172,10 +175,13 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
        cls, config_entry: ConfigEntry
    ) -> dict[str, type[ConfigSubentryFlow]]:
        """Return subentries supported by this integration."""
        return {"conversation": ConversationSubentryFlowHandler}
        return {
            "conversation": LLMSubentryFlowHandler,
            "ai_task": LLMSubentryFlowHandler,
        }


class ConversationSubentryFlowHandler(ConfigSubentryFlow):
class LLMSubentryFlowHandler(ConfigSubentryFlow):
    """Flow for managing conversation subentries."""

    last_rendered_recommended = False
@@ -201,12 +207,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
        errors: dict[str, str] = {}

        if user_input is None:
            if self._is_new:
                options = RECOMMENDED_OPTIONS.copy()
            else:
            if not self._is_new:
                # If this is a reconfiguration, we need to copy the existing options
                # so that we can show the current values in the form.
                options = self._get_reconfigure_subentry().data.copy()
            elif self._subentry_type == "ai_task":
                options = RECOMMENDED_AI_TASK_OPTIONS.copy()
            else:
                options = RECOMMENDED_CONVERSATION_OPTIONS.copy()

            self.last_rendered_recommended = cast(
                bool, options.get(CONF_RECOMMENDED, False)
@@ -216,7 +224,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
            if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
                if not user_input.get(CONF_LLM_HASS_API):
                    user_input.pop(CONF_LLM_HASS_API, None)
                # Don't allow to save options that enable the Google Seearch tool with an Assist API
                # Don't allow to save options that enable the Google Search tool with an Assist API
                if not (
                    user_input.get(CONF_LLM_HASS_API)
                    and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
@@ -240,7 +248,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
            options = user_input

        schema = await google_generative_ai_config_option_schema(
            self.hass, self._is_new, options, self._genai_client
            self.hass, self._is_new, self._subentry_type, options, self._genai_client
        )
        return self.async_show_form(
            step_id="set_options", data_schema=vol.Schema(schema), errors=errors
@@ -253,6 +261,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
async def google_generative_ai_config_option_schema(
    hass: HomeAssistant,
    is_new: bool,
    subentry_type: str,
    options: Mapping[str, Any],
    genai_client: genai.Client,
) -> dict:
@@ -270,37 +279,56 @@ async def google_generative_ai_config_option_schema(
        suggested_llm_apis = [suggested_llm_apis]

    if is_new:
        if CONF_NAME in options:
            default_name = options[CONF_NAME]
        elif subentry_type == "ai_task":
            default_name = DEFAULT_AI_TASK_NAME
        else:
            default_name = DEFAULT_CONVERSATION_NAME
        schema: dict[vol.Required | vol.Optional, Any] = {
            vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str,
            vol.Required(CONF_NAME, default=default_name): str,
        }
    else:
        schema = {}

    schema.update(
        {
            vol.Optional(
                CONF_PROMPT,
                description={
                    "suggested_value": options.get(
                        CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
                    )
                },
            ): TemplateSelector(),
            vol.Optional(
                CONF_LLM_HASS_API,
                description={"suggested_value": suggested_llm_apis},
            ): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
            vol.Required(
                CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
            ): bool,
        }
    )
    if subentry_type == "conversation":
        schema.update(
            {
                vol.Optional(
                    CONF_PROMPT,
                    description={
                        "suggested_value": options.get(
                            CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
                        )
                    },
                ): TemplateSelector(),
                vol.Optional(
                    CONF_LLM_HASS_API,
                    description={"suggested_value": suggested_llm_apis},
                ): SelectSelector(
                    SelectSelectorConfig(options=hass_apis, multiple=True)
                ),
                vol.Required(
                    CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
                ): bool,
            }
        )
    else:
        # For ai_task and tts subentry types
        schema.update(
            {
                vol.Required(
                    CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
                ): bool,
            }
        )

    if options.get(CONF_RECOMMENDED):
        return schema

    api_models_pager = await genai_client.aio.models.list(config={"query_base": True})
    api_models = [api_model async for api_model in api_models_pager]

    models = [
        SelectOptionDict(
            label=api_model.display_name,
@@ -396,13 +424,17 @@ async def google_generative_ai_config_option_schema(
                },
                default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
            ): harm_block_thresholds_selector,
        }
    )
    if subentry_type == "conversation":
        schema[
            vol.Optional(
                CONF_USE_GOOGLE_SEARCH_TOOL,
                description={
                    "suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
                },
                default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
            ): bool,
        }
    )
            )
        ] = bool

    return schema
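Taken together, the branches above mean that a new ai_task subentry with recommended options enabled collapses to just a name field and the recommended toggle; the prompt, LLM API, and Google Search options are now gated on `subentry_type == "conversation"`. A rough standalone sketch of the equivalent form (the real function returns a plain dict that the flow later wraps in `vol.Schema`; this constant is only illustrative):

```python
import voluptuous as vol

from homeassistant.components.google_generative_ai_conversation.const import (
    CONF_RECOMMENDED,
    DEFAULT_AI_TASK_NAME,
)
from homeassistant.const import CONF_NAME

# Roughly what the set_options form reduces to for a new "ai_task" subentry
# while CONF_RECOMMENDED stays enabled; model and tuning fields only appear
# once the recommended toggle is switched off.
AI_TASK_RECOMMENDED_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_NAME, default=DEFAULT_AI_TASK_NAME): str,
        vol.Required(CONF_RECOMMENDED, default=True): bool,
    }
)
```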
@@ -2,11 +2,16 @@

import logging

from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm

DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"

DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
DEFAULT_AI_TASK_NAME = "Google AI Task"


ATTR_MODEL = "model"
CONF_RECOMMENDED = "recommended"
@@ -31,3 +36,13 @@ RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False

TIMEOUT_MILLIS = 10000
FILE_POLLING_INTERVAL_SECONDS = 0.05

RECOMMENDED_CONVERSATION_OPTIONS = {
    CONF_RECOMMENDED: True,
    CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
    CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}

RECOMMENDED_AI_TASK_OPTIONS = {
    CONF_RECOMMENDED: True,
}
@@ -61,6 +61,33 @@
      "error": {
        "invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
      }
    },
    "ai_task": {
      "initiate_flow": {
        "user": "Add Generate data with AI service",
        "reconfigure": "Reconfigure Generate data with AI service"
      },
      "entry_type": "Generate data with AI service",
      "step": {
        "set_options": {
          "data": {
            "name": "[%key:common::config_flow::data::name%]",
            "recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
            "chat_model": "[%key:common::generic::model%]",
            "temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
            "top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
            "top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
            "max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
            "harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
            "hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
            "sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
            "dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
          }
        }
      },
      "abort": {
        "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
      }
    }
  },
  "services": {
@@ -3420,6 +3420,11 @@ class ConfigSubentryFlow(
        """Return config entry id."""
        return self.handler[0]

    @property
    def _subentry_type(self) -> str:
        """Return type of subentry we are editing/creating."""
        return self.handler[1]

    @callback
    def _get_entry(self) -> ConfigEntry:
        """Return the config entry linked to the current context."""
@@ -7,6 +7,7 @@ import pytest

from homeassistant.components.google_generative_ai_conversation.const import (
    CONF_USE_GOOGLE_SEARCH_TOOL,
    DEFAULT_AI_TASK_NAME,
    DEFAULT_CONVERSATION_NAME,
)
from homeassistant.config_entries import ConfigEntry
@@ -28,13 +29,20 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
            "api_key": "bla",
        },
        version=2,
        minor_version=1,
        subentries_data=[
            {
                "data": {},
                "subentry_type": "conversation",
                "title": DEFAULT_CONVERSATION_NAME,
                "unique_id": None,
            }
            },
            {
                "data": {},
                "subentry_type": "ai_task",
                "title": DEFAULT_AI_TASK_NAME,
                "unique_id": None,
            },
        ],
    )
    entry.runtime_data = Mock()
@@ -0,0 +1,62 @@
"""Test AI Task platform of Google Generative AI Conversation integration."""

from unittest.mock import AsyncMock

from google.genai.types import GenerateContentResponse
import pytest

from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er

from tests.common import MockConfigEntry
from tests.components.conversation import (
    MockChatLog,
    mock_chat_log,  # noqa: F401
)


@pytest.mark.usefixtures("mock_init_component")
async def test_run_task(
    hass: HomeAssistant,
    mock_config_entry: MockConfigEntry,
    mock_chat_log: MockChatLog,  # noqa: F811
    mock_send_message_stream: AsyncMock,
    entity_registry: er.EntityRegistry,
) -> None:
    """Test empty response."""
    entity_id = "ai_task.google_ai_task"

    # Ensure it's linked to the subentry
    entity_entry = entity_registry.async_get(entity_id)
    ai_task_entry = next(
        iter(
            entry
            for entry in mock_config_entry.subentries.values()
            if entry.subentry_type == "ai_task"
        )
    )
    assert entity_entry.config_entry_id == mock_config_entry.entry_id
    assert entity_entry.config_subentry_id == ai_task_entry.subentry_id

    mock_send_message_stream.return_value = [
        [
            GenerateContentResponse(
                candidates=[
                    {
                        "content": {
                            "parts": [{"text": "Hi there!"}],
                            "role": "model",
                        },
                    }
                ],
            ),
        ],
    ]
    result = await ai_task.async_generate_text(
        hass,
        task_name="Test Task",
        entity_id=entity_id,
        instructions="Test prompt",
    )
    assert result.text == "Hi there!"
@@ -6,9 +6,6 @@ import pytest
from requests.exceptions import Timeout

from homeassistant import config_entries
from homeassistant.components.google_generative_ai_conversation.config_flow import (
    RECOMMENDED_OPTIONS,
)
from homeassistant.components.google_generative_ai_conversation.const import (
    CONF_CHAT_MODEL,
    CONF_DANGEROUS_BLOCK_THRESHOLD,
@@ -22,9 +19,12 @@ from homeassistant.components.google_generative_ai_conversation.const import (
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_USE_GOOGLE_SEARCH_TOOL,
    DEFAULT_AI_TASK_NAME,
    DEFAULT_CONVERSATION_NAME,
    DOMAIN,
    RECOMMENDED_AI_TASK_OPTIONS,
    RECOMMENDED_CHAT_MODEL,
    RECOMMENDED_CONVERSATION_OPTIONS,
    RECOMMENDED_HARM_BLOCK_THRESHOLD,
    RECOMMENDED_MAX_TOKENS,
    RECOMMENDED_TOP_K,
@@ -115,10 +115,16 @@ async def test_form(hass: HomeAssistant) -> None:
    assert result2["subentries"] == [
        {
            "subentry_type": "conversation",
            "data": RECOMMENDED_OPTIONS,
            "data": RECOMMENDED_CONVERSATION_OPTIONS,
            "title": DEFAULT_CONVERSATION_NAME,
            "unique_id": None,
        }
        },
        {
            "subentry_type": "ai_task",
            "data": RECOMMENDED_AI_TASK_OPTIONS,
            "title": DEFAULT_AI_TASK_NAME,
            "unique_id": None,
        },
    ]
    assert len(mock_setup_entry.mock_calls) == 1

@@ -172,26 +178,72 @@ async def test_creating_conversation_subentry(
    ):
        result2 = await hass.config_entries.subentries.async_configure(
            result["flow_id"],
            {CONF_NAME: "Mock name", **RECOMMENDED_OPTIONS},
            {CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS},
        )
        await hass.async_block_till_done()

    assert result2["type"] is FlowResultType.CREATE_ENTRY
    assert result2["title"] == "Mock name"

    processed_options = RECOMMENDED_OPTIONS.copy()
    processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
    processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()

    assert result2["data"] == processed_options


async def test_creating_ai_task_subentry(
    hass: HomeAssistant,
    mock_init_component: None,
    mock_config_entry: MockConfigEntry,
) -> None:
    """Test creating an AI task subentry."""
    with patch(
        "google.genai.models.AsyncModels.list",
        return_value=get_models_pager(),
    ):
        result = await hass.config_entries.subentries.async_init(
            (mock_config_entry.entry_id, "ai_task"),
            context={"source": config_entries.SOURCE_USER},
        )

    assert result["type"] is FlowResultType.FORM, result
    assert result["step_id"] == "set_options"
    assert not result["errors"]

    old_subentries = set(mock_config_entry.subentries)

    with patch(
        "google.genai.models.AsyncModels.list",
        return_value=get_models_pager(),
    ):
        result2 = await hass.config_entries.subentries.async_configure(
            result["flow_id"],
            {CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
        )
        await hass.async_block_till_done()

    assert result2["type"] is FlowResultType.CREATE_ENTRY
    assert result2["title"] == "Mock AI Task"
    assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS

    assert len(mock_config_entry.subentries) == 3

    new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
    new_subentry = mock_config_entry.subentries[new_subentry_id]

    assert new_subentry.subentry_type == "ai_task"
    assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
    assert new_subentry.title == "Mock AI Task"


async def test_creating_conversation_subentry_not_loaded(
    hass: HomeAssistant,
    mock_init_component: None,
    mock_config_entry: MockConfigEntry,
) -> None:
    """Test creating a conversation subentry."""
    """Test that subentry fails to init if entry not loaded."""
    await hass.config_entries.async_unload(mock_config_entry.entry_id)

    with patch(
        "google.genai.models.AsyncModels.list",
        return_value=get_models_pager(),
@@ -7,7 +7,13 @@ import pytest
from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion

from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
from homeassistant.components.google_generative_ai_conversation.const import (
    DEFAULT_AI_TASK_NAME,
    DEFAULT_CONVERSATION_NAME,
    DOMAIN,
    RECOMMENDED_AI_TASK_OPTIONS,
    RECOMMENDED_CONVERSATION_OPTIONS,
)
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
@@ -722,3 +728,70 @@ async def test_migration_from_v1_to_v2_with_same_keys(
    )
    assert device.identifiers == {(DOMAIN, subentry.subentry_id)}
    assert device.id == device_2.id


async def test_migrate_entry_from_v2_0_to_v2_1(hass: HomeAssistant) -> None:
    """Test migration from version 2.0 to version 2.1."""
    # Create a v2.0 config entry (subversion 0) with only conversation subentry
    mock_config_entry = MockConfigEntry(
        domain=DOMAIN,
        data={CONF_API_KEY: "test-api-key"},
        version=2,
        minor_version=0,
        subentries_data=[
            {
                "data": RECOMMENDED_CONVERSATION_OPTIONS,
                "subentry_type": "conversation",
                "title": DEFAULT_CONVERSATION_NAME,
                "unique_id": None,
            }
        ],
    )
    mock_config_entry.add_to_hass(hass)

    # Verify initial state
    assert mock_config_entry.version == 2
    assert mock_config_entry.minor_version == 0
    assert len(mock_config_entry.subentries) == 1
    assert (
        list(mock_config_entry.subentries.values())[0].subentry_type == "conversation"
    )

    # Run setup to trigger migration
    with patch(
        "homeassistant.components.google_generative_ai_conversation.async_setup_entry",
        return_value=True,
    ):
        result = await hass.config_entries.async_setup(mock_config_entry.entry_id)
        assert result is True
        await hass.async_block_till_done()

    # Verify migration completed
    entries = hass.config_entries.async_entries(DOMAIN)
    assert len(entries) == 1
    entry = entries[0]

    # Check version and subversion were updated
    assert entry.version == 2
    assert entry.minor_version == 1

    # Check we now have both conversation and ai_task subentries
    assert len(entry.subentries) == 2

    subentries = {
        subentry.subentry_type: subentry for subentry in entry.subentries.values()
    }
    assert "conversation" in subentries
    assert "ai_task" in subentries

    # Find and verify the ai_task subentry
    ai_task_subentry = subentries["ai_task"]
    assert ai_task_subentry is not None
    assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
    assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS

    # Verify conversation subentry is still there and unchanged
    conversation_subentry = subentries["conversation"]
    assert conversation_subentry is not None
    assert conversation_subentry.title == DEFAULT_CONVERSATION_NAME
    assert conversation_subentry.data == RECOMMENDED_CONVERSATION_OPTIONS