forked from home-assistant/core
Compare commits
2 Commits
google-sub
...
google-llm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00baa90c78 | ||
|
|
3db5e1b551 |
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from types import MappingProxyType
|
||||
|
||||
from google.genai import Client
|
||||
from google.genai.errors import APIError, ClientError
|
||||
@@ -33,10 +34,12 @@ from homeassistant.helpers.typing import ConfigType
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_PROMPT,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DOMAIN,
|
||||
FILE_POLLING_INTERVAL_SECONDS,
|
||||
LOGGER,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
TIMEOUT_MILLIS,
|
||||
)
|
||||
@@ -47,6 +50,7 @@ CONF_FILENAMES = "filenames"
|
||||
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
PLATFORMS = (
|
||||
Platform.AI_TASK,
|
||||
Platform.CONVERSATION,
|
||||
Platform.TTS,
|
||||
)
|
||||
@@ -219,7 +223,7 @@ async def async_migrate_entry(
|
||||
if entry.version == 1:
|
||||
# Migrate from version 1 to version 2
|
||||
# Move conversation-specific options to a subentry
|
||||
subentry = ConfigSubentry(
|
||||
conversation_subentry = ConfigSubentry(
|
||||
data=entry.options,
|
||||
subentry_type="conversation",
|
||||
title=DEFAULT_CONVERSATION_NAME,
|
||||
@@ -227,7 +231,16 @@ async def async_migrate_entry(
|
||||
)
|
||||
hass.config_entries.async_add_subentry(
|
||||
entry,
|
||||
subentry,
|
||||
conversation_subentry,
|
||||
)
|
||||
hass.config_entries.async_add_subentry(
|
||||
entry,
|
||||
ConfigSubentry(
|
||||
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
|
||||
subentry_type="ai_task",
|
||||
title=DEFAULT_AI_TASK_NAME,
|
||||
unique_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
# Migrate conversation entity to be linked to subentry
|
||||
@@ -236,8 +249,8 @@ async def async_migrate_entry(
|
||||
if entity_entry.domain == Platform.CONVERSATION:
|
||||
ent_reg.async_update_entity(
|
||||
entity_entry.entity_id,
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
new_unique_id=subentry.subentry_id,
|
||||
config_subentry_id=conversation_subentry.subentry_id,
|
||||
new_unique_id=conversation_subentry.subentry_id,
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
"""LLM Task integration for Google Generative AI Conversation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components import ai_task, conversation
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from .const import LOGGER
|
||||
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up LLM Task entities."""
|
||||
entities = [
|
||||
GoogleGenerativeAILLMTaskEntity(config_entry, subentry)
|
||||
for subentry in config_entry.subentries.values()
|
||||
if subentry.subentry_type == "ai_task"
|
||||
]
|
||||
async_add_entities(entities)
|
||||
|
||||
|
||||
class GoogleGenerativeAILLMTaskEntity(
|
||||
ai_task.AITaskEntity,
|
||||
GoogleGenerativeAILLMBaseEntity,
|
||||
):
|
||||
"""Google Generative AI AI Task entity."""
|
||||
|
||||
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_TEXT
|
||||
|
||||
async def _async_generate_text(
|
||||
self,
|
||||
task: ai_task.GenTextTask,
|
||||
chat_log: conversation.ChatLog,
|
||||
) -> ai_task.GenTextTaskResult:
|
||||
"""Handle a generate text task."""
|
||||
await self._async_handle_chat_log(chat_log)
|
||||
|
||||
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
|
||||
LOGGER.error(
|
||||
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
|
||||
chat_log.content[-1],
|
||||
)
|
||||
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
|
||||
|
||||
return ai_task.GenTextTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
text=chat_log.content[-1].content or "",
|
||||
)
|
||||
@@ -45,9 +45,12 @@ from .const import (
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TEMPERATURE,
|
||||
@@ -65,12 +68,6 @@ STEP_API_DATA_SCHEMA = vol.Schema(
|
||||
}
|
||||
)
|
||||
|
||||
RECOMMENDED_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
}
|
||||
|
||||
|
||||
async def validate_input(data: dict[str, Any]) -> None:
|
||||
"""Validate the user input allows us to connect.
|
||||
@@ -121,10 +118,16 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
subentries=[
|
||||
{
|
||||
"subentry_type": "conversation",
|
||||
"data": RECOMMENDED_OPTIONS,
|
||||
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
"title": DEFAULT_CONVERSATION_NAME,
|
||||
"unique_id": None,
|
||||
}
|
||||
},
|
||||
{
|
||||
"subentry_type": "ai_task",
|
||||
"data": RECOMMENDED_AI_TASK_OPTIONS,
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
return self.async_show_form(
|
||||
@@ -170,10 +173,13 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
cls, config_entry: ConfigEntry
|
||||
) -> dict[str, type[ConfigSubentryFlow]]:
|
||||
"""Return subentries supported by this integration."""
|
||||
return {"conversation": ConversationSubentryFlowHandler}
|
||||
return {
|
||||
"conversation": LLMSubentryFlowHandler,
|
||||
"ai_task": LLMSubentryFlowHandler,
|
||||
}
|
||||
|
||||
|
||||
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
class LLMSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing conversation subentries."""
|
||||
|
||||
last_rendered_recommended = False
|
||||
@@ -190,7 +196,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
) -> SubentryFlowResult:
|
||||
"""Add a subentry."""
|
||||
self.is_new = True
|
||||
self.start_data = RECOMMENDED_OPTIONS.copy()
|
||||
if self._subentry_type == "ai_task":
|
||||
self.start_data = RECOMMENDED_AI_TASK_OPTIONS.copy()
|
||||
else:
|
||||
self.start_data = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||
return await self.async_step_set_options()
|
||||
|
||||
async def async_step_reconfigure(
|
||||
@@ -212,7 +221,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
|
||||
if not user_input.get(CONF_LLM_HASS_API):
|
||||
user_input.pop(CONF_LLM_HASS_API, None)
|
||||
# Don't allow to save options that enable the Google Seearch tool with an Assist API
|
||||
# Don't allow to save options that enable the Google Search tool with an Assist API
|
||||
if not (
|
||||
user_input.get(CONF_LLM_HASS_API)
|
||||
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
|
||||
@@ -238,7 +247,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
self.last_rendered_recommended = options.get(CONF_RECOMMENDED, False)
|
||||
|
||||
schema = await google_generative_ai_config_option_schema(
|
||||
self.hass, self.is_new, options, self._genai_client
|
||||
self.hass, self.is_new, self._subentry_type, options, self._genai_client
|
||||
)
|
||||
return self.async_show_form(
|
||||
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
|
||||
@@ -248,6 +257,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
async def google_generative_ai_config_option_schema(
|
||||
hass: HomeAssistant,
|
||||
is_new: bool,
|
||||
subentry_type: str,
|
||||
options: Mapping[str, Any],
|
||||
genai_client: genai.Client,
|
||||
) -> dict:
|
||||
@@ -265,37 +275,56 @@ async def google_generative_ai_config_option_schema(
|
||||
suggested_llm_apis = [suggested_llm_apis]
|
||||
|
||||
if is_new:
|
||||
if CONF_NAME in options:
|
||||
default_name = options[CONF_NAME]
|
||||
elif subentry_type == "ai_task":
|
||||
default_name = DEFAULT_AI_TASK_NAME
|
||||
else:
|
||||
default_name = DEFAULT_CONVERSATION_NAME
|
||||
schema: dict[vol.Required | vol.Optional, Any] = {
|
||||
vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str,
|
||||
vol.Required(CONF_NAME, default=default_name): str,
|
||||
}
|
||||
else:
|
||||
schema = {}
|
||||
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
"suggested_value": options.get(
|
||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||
)
|
||||
},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": suggested_llm_apis},
|
||||
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
|
||||
vol.Required(
|
||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
if subentry_type == "conversation":
|
||||
schema.update(
|
||||
{
|
||||
vol.Optional(
|
||||
CONF_PROMPT,
|
||||
description={
|
||||
"suggested_value": options.get(
|
||||
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
|
||||
)
|
||||
},
|
||||
): TemplateSelector(),
|
||||
vol.Optional(
|
||||
CONF_LLM_HASS_API,
|
||||
description={"suggested_value": suggested_llm_apis},
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(options=hass_apis, multiple=True)
|
||||
),
|
||||
vol.Required(
|
||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# For ai_task and tts subentry types
|
||||
schema.update(
|
||||
{
|
||||
vol.Required(
|
||||
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
|
||||
if options.get(CONF_RECOMMENDED):
|
||||
return schema
|
||||
|
||||
api_models_pager = await genai_client.aio.models.list(config={"query_base": True})
|
||||
api_models = [api_model async for api_model in api_models_pager]
|
||||
|
||||
models = [
|
||||
SelectOptionDict(
|
||||
label=api_model.display_name,
|
||||
@@ -391,13 +420,17 @@ async def google_generative_ai_config_option_schema(
|
||||
},
|
||||
default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
): harm_block_thresholds_selector,
|
||||
}
|
||||
)
|
||||
if subentry_type == "conversation":
|
||||
schema[
|
||||
vol.Optional(
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
description={
|
||||
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
|
||||
},
|
||||
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
|
||||
): bool,
|
||||
}
|
||||
)
|
||||
)
|
||||
] = bool
|
||||
|
||||
return schema
|
||||
|
||||
@@ -2,11 +2,16 @@
|
||||
|
||||
import logging
|
||||
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
DOMAIN = "google_generative_ai_conversation"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
CONF_PROMPT = "prompt"
|
||||
|
||||
DEFAULT_CONVERSATION_NAME = "Google Conversation"
|
||||
DEFAULT_AI_TASK_NAME = "Google AI Task"
|
||||
|
||||
|
||||
ATTR_MODEL = "model"
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
@@ -31,3 +36,13 @@ RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
|
||||
|
||||
TIMEOUT_MILLIS = 10000
|
||||
FILE_POLLING_INTERVAL_SECONDS = 0.05
|
||||
|
||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
}
|
||||
|
||||
RECOMMENDED_AI_TASK_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
}
|
||||
|
||||
@@ -59,6 +59,33 @@
|
||||
"error": {
|
||||
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
|
||||
}
|
||||
},
|
||||
"ai_task": {
|
||||
"initiate_flow": {
|
||||
"user": "Add AI task service",
|
||||
"reconfigure": "Reconfigure AI task service"
|
||||
},
|
||||
"entry_type": "AI task service",
|
||||
"step": {
|
||||
"set_options": {
|
||||
"data": {
|
||||
"name": "[%key:common::config_flow::data::name%]",
|
||||
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
|
||||
"chat_model": "[%key:common::generic::model%]",
|
||||
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
|
||||
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
|
||||
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
|
||||
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
|
||||
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
|
||||
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
|
||||
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
|
||||
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"services": {
|
||||
|
||||
@@ -3420,6 +3420,11 @@ class ConfigSubentryFlow(
|
||||
"""Return config entry id."""
|
||||
return self.handler[0]
|
||||
|
||||
@property
|
||||
def _subentry_type(self) -> str:
|
||||
"""Return type of subentry we are editing/creating."""
|
||||
return self.handler[1]
|
||||
|
||||
@callback
|
||||
def _get_entry(self) -> ConfigEntry:
|
||||
"""Return the config entry linked to the current context."""
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
@@ -34,7 +35,13 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
"subentry_type": "conversation",
|
||||
"title": DEFAULT_CONVERSATION_NAME,
|
||||
"unique_id": None,
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {},
|
||||
"subentry_type": "ai_task",
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
entry.runtime_data = Mock()
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Test AI Task platform of Google Generative AI Conversation integration."""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from google.genai.types import GenerateContentResponse
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import ai_task
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.components.conversation import (
|
||||
MockChatLog,
|
||||
mock_chat_log, # noqa: F401
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("mock_init_component")
|
||||
async def test_run_task(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test empty response."""
|
||||
entity_id = "ai_task.google_ai_task"
|
||||
mock_send_message_stream.return_value = [
|
||||
[
|
||||
GenerateContentResponse(
|
||||
candidates=[
|
||||
{
|
||||
"content": {
|
||||
"parts": [{"text": "Hi there!"}],
|
||||
"role": "model",
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
],
|
||||
]
|
||||
result = await ai_task.async_generate_text(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=entity_id,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.text == "Hi there!"
|
||||
@@ -6,9 +6,6 @@ import pytest
|
||||
from requests.exceptions import Timeout
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.google_generative_ai_conversation.config_flow import (
|
||||
RECOMMENDED_OPTIONS,
|
||||
)
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_CHAT_MODEL,
|
||||
CONF_DANGEROUS_BLOCK_THRESHOLD,
|
||||
@@ -22,8 +19,12 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
RECOMMENDED_MAX_TOKENS,
|
||||
RECOMMENDED_TOP_K,
|
||||
@@ -114,10 +115,16 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
assert result2["subentries"] == [
|
||||
{
|
||||
"subentry_type": "conversation",
|
||||
"data": RECOMMENDED_OPTIONS,
|
||||
"title": "Google Conversation",
|
||||
"data": RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
"title": DEFAULT_CONVERSATION_NAME,
|
||||
"unique_id": None,
|
||||
}
|
||||
},
|
||||
{
|
||||
"subentry_type": "ai_task",
|
||||
"data": RECOMMENDED_AI_TASK_OPTIONS,
|
||||
"title": DEFAULT_AI_TASK_NAME,
|
||||
"unique_id": None,
|
||||
},
|
||||
]
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
@@ -149,19 +156,65 @@ async def test_creating_conversation_subentry(
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_NAME: "Mock name", **RECOMMENDED_OPTIONS},
|
||||
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Mock name"
|
||||
|
||||
processed_options = RECOMMENDED_OPTIONS.copy()
|
||||
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
|
||||
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
|
||||
|
||||
assert result2["data"] == processed_options
|
||||
|
||||
|
||||
async def test_creating_ai_task_subentry(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test creating an AI task subentry."""
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "set_options"
|
||||
assert not result["errors"]
|
||||
|
||||
old_subentries = set(mock_config_entry.subentries)
|
||||
|
||||
with patch(
|
||||
"google.genai.models.AsyncModels.list",
|
||||
return_value=get_models_pager(),
|
||||
):
|
||||
result2 = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["title"] == "Mock AI Task"
|
||||
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
|
||||
|
||||
assert len(mock_config_entry.subentries) == 3
|
||||
|
||||
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
|
||||
new_subentry = mock_config_entry.subentries[new_subentry_id]
|
||||
|
||||
assert new_subentry.subentry_type == "ai_task"
|
||||
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
assert new_subentry.title == "Mock AI Task"
|
||||
|
||||
|
||||
def will_options_be_rendered_again(current_options, new_options) -> bool:
|
||||
"""Determine if options will be rendered again."""
|
||||
return current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED)
|
||||
|
||||
@@ -10,7 +10,12 @@ from syrupy.assertion import SnapshotAssertion
|
||||
from homeassistant.components.google_generative_ai_conversation import (
|
||||
async_migrate_entry,
|
||||
)
|
||||
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
DEFAULT_AI_TASK_NAME,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_AI_TASK_OPTIONS,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -441,16 +446,27 @@ async def test_migration_from_v1_to_v2(
|
||||
assert mock_config_entry.data == {"api_key": "1234"}
|
||||
assert mock_config_entry.options == {}
|
||||
|
||||
assert len(mock_config_entry.subentries) == 1
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
assert subentry.unique_id is None
|
||||
assert subentry.title == "Google Conversation"
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == OPTIONS
|
||||
subentries = {
|
||||
subentry.subentry_type: subentry
|
||||
for subentry in mock_config_entry.subentries.values()
|
||||
}
|
||||
|
||||
conversation_subentry = subentries["conversation"]
|
||||
assert conversation_subentry.unique_id is None
|
||||
assert conversation_subentry.title == DEFAULT_CONVERSATION_NAME
|
||||
assert conversation_subentry.subentry_type == "conversation"
|
||||
assert conversation_subentry.data == OPTIONS
|
||||
|
||||
migrated_entity = entity_registry.async_get(entity.entity_id)
|
||||
assert migrated_entity is not None
|
||||
assert migrated_entity.config_entry_id == mock_config_entry.entry_id
|
||||
assert migrated_entity.config_subentry_id == subentry.subentry_id
|
||||
assert migrated_entity.config_subentry_id == conversation_subentry.subentry_id
|
||||
assert migrated_entity.device_id == device.id
|
||||
|
||||
ai_task_subentry = subentries["ai_task"]
|
||||
assert ai_task_subentry.unique_id is None
|
||||
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
|
||||
assert ai_task_subentry.subentry_type == "ai_task"
|
||||
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
|
||||
|
||||
Reference in New Issue
Block a user