Add Google AI Task platform

This commit is contained in:
Paulus Schoutsen
2025-06-13 13:15:26 -04:00
parent 0cf7952964
commit 356f8aad1d
10 changed files with 405 additions and 48 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import mimetypes
from pathlib import Path
from types import MappingProxyType
from google.genai import Client
from google.genai.errors import APIError, ClientError
@ -37,9 +38,11 @@ from homeassistant.helpers.typing import ConfigType
from .const import (
CONF_CHAT_MODEL,
CONF_PROMPT,
DEFAULT_AI_TASK_NAME,
DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
TIMEOUT_MILLIS,
)
@ -50,6 +53,7 @@ CONF_FILENAMES = "filenames"
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (
Platform.AI_TASK,
Platform.CONVERSATION,
Platform.TTS,
)
@ -280,3 +284,25 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
options={},
version=2,
)
async def async_migrate_entry(
hass: HomeAssistant, entry: GoogleGenerativeAIConfigEntry
) -> bool:
"""Migrate old entry."""
if entry.version == 2 and entry.minor_version == 0:
# Migrate from version 2.0 to 2.1
# Add AI Task subentry with default options
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
data=MappingProxyType(RECOMMENDED_AI_TASK_OPTIONS),
subentry_type="ai_task",
title=DEFAULT_AI_TASK_NAME,
unique_id=None,
),
)
hass.config_entries.async_update_entry(entry, minor_version=1)
return True
return False

View File

@ -0,0 +1,57 @@
"""AI Task integration for Google Generative AI Conversation."""
from __future__ import annotations
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import LOGGER
from .entity import ERROR_GETTING_RESPONSE, GoogleGenerativeAILLMBaseEntity
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up AI Task entities."""
for subentry in config_entry.subentries.values():
if subentry.subentry_type != "ai_task":
continue
async_add_entities(
[GoogleGenerativeAITaskEntity(config_entry, subentry)],
config_subentry_id=subentry.subentry_id,
)
class GoogleGenerativeAITaskEntity(
ai_task.AITaskEntity,
GoogleGenerativeAILLMBaseEntity,
):
"""Google Generative AI AI Task entity."""
_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_TEXT
async def _async_generate_text(
self,
task: ai_task.GenTextTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenTextTaskResult:
"""Handle a generate text task."""
await self._async_handle_chat_log(chat_log)
if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(
"Last content in chat log is not an AssistantContent: %s. This could be due to the model not returning a valid response",
chat_log.content[-1],
)
raise HomeAssistantError(ERROR_GETTING_RESPONSE)
return ai_task.GenTextTaskResult(
conversation_id=chat_log.conversation_id,
text=chat_log.content[-1].content or "",
)

View File

@ -46,9 +46,12 @@ from .const import (
CONF_TOP_K,
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
@ -66,12 +69,6 @@ STEP_API_DATA_SCHEMA = vol.Schema(
}
)
RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
async def validate_input(data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
@ -123,10 +120,16 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
subentries=[
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
},
{
"subentry_type": "ai_task",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
],
)
return self.async_show_form(
@ -172,10 +175,13 @@ class GoogleGenerativeAIConfigFlow(ConfigFlow, domain=DOMAIN):
cls, config_entry: ConfigEntry
) -> dict[str, type[ConfigSubentryFlow]]:
"""Return subentries supported by this integration."""
return {"conversation": ConversationSubentryFlowHandler}
return {
"conversation": LLMSubentryFlowHandler,
"ai_task": LLMSubentryFlowHandler,
}
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
class LLMSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
last_rendered_recommended = False
@ -201,12 +207,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
errors: dict[str, str] = {}
if user_input is None:
if self._is_new:
options = RECOMMENDED_OPTIONS.copy()
else:
if not self._is_new:
# If this is a reconfiguration, we need to copy the existing options
# so that we can show the current values in the form.
options = self._get_reconfigure_subentry().data.copy()
elif self._subentry_type == "ai_task":
options = RECOMMENDED_AI_TASK_OPTIONS.copy()
else:
options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
self.last_rendered_recommended = cast(
bool, options.get(CONF_RECOMMENDED, False)
@ -216,7 +224,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None)
# Don't allow to save options that enable the Google Seearch tool with an Assist API
# Don't allow to save options that enable the Google Search tool with an Assist API
if not (
user_input.get(CONF_LLM_HASS_API)
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
@ -240,7 +248,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
options = user_input
schema = await google_generative_ai_config_option_schema(
self.hass, self._is_new, options, self._genai_client
self.hass, self._is_new, self._subentry_type, options, self._genai_client
)
return self.async_show_form(
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
@ -253,6 +261,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
async def google_generative_ai_config_option_schema(
hass: HomeAssistant,
is_new: bool,
subentry_type: str,
options: Mapping[str, Any],
genai_client: genai.Client,
) -> dict:
@ -270,37 +279,56 @@ async def google_generative_ai_config_option_schema(
suggested_llm_apis = [suggested_llm_apis]
if is_new:
if CONF_NAME in options:
default_name = options[CONF_NAME]
elif subentry_type == "ai_task":
default_name = DEFAULT_AI_TASK_NAME
else:
default_name = DEFAULT_CONVERSATION_NAME
schema: dict[vol.Required | vol.Optional, Any] = {
vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME): str,
vol.Required(CONF_NAME, default=default_name): str,
}
else:
schema = {}
schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
)
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": suggested_llm_apis},
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
}
)
if subentry_type == "conversation":
schema.update(
{
vol.Optional(
CONF_PROMPT,
description={
"suggested_value": options.get(
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
)
},
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": suggested_llm_apis},
): SelectSelector(
SelectSelectorConfig(options=hass_apis, multiple=True)
),
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
}
)
else:
# For ai_task and tts subentry types
schema.update(
{
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
}
)
if options.get(CONF_RECOMMENDED):
return schema
api_models_pager = await genai_client.aio.models.list(config={"query_base": True})
api_models = [api_model async for api_model in api_models_pager]
models = [
SelectOptionDict(
label=api_model.display_name,
@ -396,13 +424,17 @@ async def google_generative_ai_config_option_schema(
},
default=RECOMMENDED_HARM_BLOCK_THRESHOLD,
): harm_block_thresholds_selector,
}
)
if subentry_type == "conversation":
schema[
vol.Optional(
CONF_USE_GOOGLE_SEARCH_TOOL,
description={
"suggested_value": options.get(CONF_USE_GOOGLE_SEARCH_TOOL),
},
default=RECOMMENDED_USE_GOOGLE_SEARCH_TOOL,
): bool,
}
)
)
] = bool
return schema

View File

@ -2,11 +2,16 @@
import logging
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.helpers import llm
DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
DEFAULT_AI_TASK_NAME = "Google AI Task"
ATTR_MODEL = "model"
CONF_RECOMMENDED = "recommended"
@ -31,3 +36,13 @@ RECOMMENDED_USE_GOOGLE_SEARCH_TOOL = False
TIMEOUT_MILLIS = 10000
FILE_POLLING_INTERVAL_SECONDS = 0.05
RECOMMENDED_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
}
RECOMMENDED_AI_TASK_OPTIONS = {
CONF_RECOMMENDED: True,
}

View File

@ -61,6 +61,33 @@
"error": {
"invalid_google_search_option": "Google Search can only be enabled if nothing is selected in the \"Control Home Assistant\" setting."
}
},
"ai_task": {
"initiate_flow": {
"user": "Add Generate data with AI service",
"reconfigure": "Reconfigure Generate data with AI service"
},
"entry_type": "Generate data with AI service",
"step": {
"set_options": {
"data": {
"name": "[%key:common::config_flow::data::name%]",
"recommended": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::recommended%]",
"chat_model": "[%key:common::generic::model%]",
"temperature": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::temperature%]",
"top_p": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_p%]",
"top_k": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::top_k%]",
"max_tokens": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::max_tokens%]",
"harassment_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::harassment_block_threshold%]",
"hate_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::hate_block_threshold%]",
"sexual_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::sexual_block_threshold%]",
"dangerous_block_threshold": "[%key:component::google_generative_ai_conversation::config_subentries::conversation::step::set_options::data::dangerous_block_threshold%]"
}
}
},
"abort": {
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
}
}
},
"services": {

View File

@ -3420,6 +3420,11 @@ class ConfigSubentryFlow(
"""Return config entry id."""
return self.handler[0]
@property
def _subentry_type(self) -> str:
"""Return type of subentry we are editing/creating."""
return self.handler[1]
@callback
def _get_entry(self) -> ConfigEntry:
"""Return the config entry linked to the current context."""

View File

@ -7,6 +7,7 @@ import pytest
from homeassistant.components.google_generative_ai_conversation.const import (
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
)
from homeassistant.config_entries import ConfigEntry
@ -28,13 +29,20 @@ def mock_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"api_key": "bla",
},
version=2,
minor_version=1,
subentries_data=[
{
"data": {},
"subentry_type": "conversation",
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
},
{
"data": {},
"subentry_type": "ai_task",
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
],
)
entry.runtime_data = Mock()

View File

@ -0,0 +1,62 @@
"""Test AI Task platform of Google Generative AI Conversation integration."""
from unittest.mock import AsyncMock
from google.genai.types import GenerateContentResponse
import pytest
from homeassistant.components import ai_task
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from tests.common import MockConfigEntry
from tests.components.conversation import (
MockChatLog,
mock_chat_log, # noqa: F401
)
@pytest.mark.usefixtures("mock_init_component")
async def test_run_task(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_chat_log: MockChatLog, # noqa: F811
mock_send_message_stream: AsyncMock,
entity_registry: er.EntityRegistry,
) -> None:
"""Test empty response."""
entity_id = "ai_task.google_ai_task"
# Ensure it's linked to the subentry
entity_entry = entity_registry.async_get(entity_id)
ai_task_entry = next(
iter(
entry
for entry in mock_config_entry.subentries.values()
if entry.subentry_type == "ai_task"
)
)
assert entity_entry.config_entry_id == mock_config_entry.entry_id
assert entity_entry.config_subentry_id == ai_task_entry.subentry_id
mock_send_message_stream.return_value = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [{"text": "Hi there!"}],
"role": "model",
},
}
],
),
],
]
result = await ai_task.async_generate_text(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Test prompt",
)
assert result.text == "Hi there!"

View File

@ -6,9 +6,6 @@ import pytest
from requests.exceptions import Timeout
from homeassistant import config_entries
from homeassistant.components.google_generative_ai_conversation.config_flow import (
RECOMMENDED_OPTIONS,
)
from homeassistant.components.google_generative_ai_conversation.const import (
CONF_CHAT_MODEL,
CONF_DANGEROUS_BLOCK_THRESHOLD,
@ -22,9 +19,12 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_TOP_K,
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_CONVERSATION_OPTIONS,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TOP_K,
@ -115,10 +115,16 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["subentries"] == [
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
},
{
"subentry_type": "ai_task",
"data": RECOMMENDED_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME,
"unique_id": None,
},
]
assert len(mock_setup_entry.mock_calls) == 1
@ -172,26 +178,72 @@ async def test_creating_conversation_subentry(
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock name", **RECOMMENDED_OPTIONS},
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS},
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock name"
processed_options = RECOMMENDED_OPTIONS.copy()
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy()
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
assert result2["data"] == processed_options
async def test_creating_ai_task_subentry(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating an AI task subentry."""
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task"),
context={"source": config_entries.SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM, result
assert result["step_id"] == "set_options"
assert not result["errors"]
old_subentries = set(mock_config_entry.subentries)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),
):
result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"],
{CONF_NAME: "Mock AI Task", **RECOMMENDED_AI_TASK_OPTIONS},
)
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock AI Task"
assert result2["data"] == RECOMMENDED_AI_TASK_OPTIONS
assert len(mock_config_entry.subentries) == 3
new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0]
new_subentry = mock_config_entry.subentries[new_subentry_id]
assert new_subentry.subentry_type == "ai_task"
assert new_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
assert new_subentry.title == "Mock AI Task"
async def test_creating_conversation_subentry_not_loaded(
hass: HomeAssistant,
mock_init_component: None,
mock_config_entry: MockConfigEntry,
) -> None:
"""Test creating a conversation subentry."""
"""Test that subentry fails to init if entry not loaded."""
await hass.config_entries.async_unload(mock_config_entry.entry_id)
with patch(
"google.genai.models.AsyncModels.list",
return_value=get_models_pager(),

View File

@ -7,7 +7,13 @@ import pytest
from requests.exceptions import Timeout
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CONVERSATION_OPTIONS,
)
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
@ -722,3 +728,70 @@ async def test_migration_from_v1_to_v2_with_same_keys(
)
assert device.identifiers == {(DOMAIN, subentry.subentry_id)}
assert device.id == device_2.id
async def test_migrate_entry_from_v2_0_to_v2_1(hass: HomeAssistant) -> None:
"""Test migration from version 2.0 to version 2.1."""
# Create a v2.0 config entry (subversion 0) with only conversation subentry
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_API_KEY: "test-api-key"},
version=2,
minor_version=0,
subentries_data=[
{
"data": RECOMMENDED_CONVERSATION_OPTIONS,
"subentry_type": "conversation",
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
],
)
mock_config_entry.add_to_hass(hass)
# Verify initial state
assert mock_config_entry.version == 2
assert mock_config_entry.minor_version == 0
assert len(mock_config_entry.subentries) == 1
assert (
list(mock_config_entry.subentries.values())[0].subentry_type == "conversation"
)
# Run setup to trigger migration
with patch(
"homeassistant.components.google_generative_ai_conversation.async_setup_entry",
return_value=True,
):
result = await hass.config_entries.async_setup(mock_config_entry.entry_id)
assert result is True
await hass.async_block_till_done()
# Verify migration completed
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
# Check version and subversion were updated
assert entry.version == 2
assert entry.minor_version == 1
# Check we now have both conversation and ai_task subentries
assert len(entry.subentries) == 2
subentries = {
subentry.subentry_type: subentry for subentry in entry.subentries.values()
}
assert "conversation" in subentries
assert "ai_task" in subentries
# Find and verify the ai_task subentry
ai_task_subentry = subentries["ai_task"]
assert ai_task_subentry is not None
assert ai_task_subentry.title == DEFAULT_AI_TASK_NAME
assert ai_task_subentry.data == RECOMMENDED_AI_TASK_OPTIONS
# Verify conversation subentry is still there and unchanged
conversation_subentry = subentries["conversation"]
assert conversation_subentry is not None
assert conversation_subentry.title == DEFAULT_CONVERSATION_NAME
assert conversation_subentry.data == RECOMMENDED_CONVERSATION_OPTIONS