Address comments

This commit is contained in:
Paulus Schoutsen
2025-06-22 12:01:08 +00:00
parent 01459d0f35
commit 95a79d5b0c
7 changed files with 60 additions and 49 deletions

View File

@@ -232,14 +232,17 @@ async def async_migrate_entry(
# Migrate conversation entity to be linked to subentry
ent_reg = er.async_get(hass)
for entity_entry in er.async_entries_for_config_entry(ent_reg, entry.entry_id):
if entity_entry.domain == Platform.CONVERSATION:
ent_reg.async_update_entity(
entity_entry.entity_id,
config_subentry_id=subentry.subentry_id,
new_unique_id=subentry.subentry_id,
)
break
conversation_entity = ent_reg.async_get_entity_id(
"conversation",
DOMAIN,
entry.entry_id,
)
if conversation_entity is not None:
ent_reg.async_update_entity(
conversation_entity,
config_subentry_id=subentry.subentry_id,
new_unique_id=subentry.subentry_id,
)
# Remove options from the main entry
hass.config_entries.async_update_entry(

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping
import logging
from typing import Any
from typing import Any, cast
from google import genai
from google.genai.errors import APIError, ClientError
@@ -177,38 +177,36 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries."""
last_rendered_recommended = False
is_new: bool
start_data: dict[str, Any]
@property
def _genai_client(self) -> genai.Client:
"""Return the Google Generative AI client."""
return self._get_entry().runtime_data
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Add a subentry."""
self.is_new = True
self.start_data = RECOMMENDED_OPTIONS.copy()
return await self.async_step_set_options()
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Handle reconfiguration of a subentry."""
self.is_new = False
self.start_data = self._get_reconfigure_subentry().data.copy()
return await self.async_step_set_options()
@property
def _is_new(self) -> bool:
"""Return if this is a new subentry."""
return self.source == "user"
async def async_step_set_options(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Set conversation options."""
options = self.start_data
errors: dict[str, str] = {}
if user_input is not None:
if user_input is None:
if self._is_new:
options = RECOMMENDED_OPTIONS.copy()
else:
# If this is a reconfiguration, we need to copy the existing options
# so that we can show the current values in the form.
options = self._get_reconfigure_subentry().data.copy()
self.last_rendered_recommended = cast(
bool, options.get(CONF_RECOMMENDED, False)
)
else:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None)
@@ -217,7 +215,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
user_input.get(CONF_LLM_HASS_API)
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
):
if self.is_new:
if self._is_new:
return self.async_create_entry(
title=user_input.pop(CONF_NAME),
data=user_input,
@@ -234,16 +232,17 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = user_input
else:
self.last_rendered_recommended = options.get(CONF_RECOMMENDED, False)
schema = await google_generative_ai_config_option_schema(
self.hass, self.is_new, options, self._genai_client
self.hass, self._is_new, options, self._genai_client
)
return self.async_show_form(
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
)
async_step_reconfigure = async_step_set_options
async_step_user = async_step_set_options
async def google_generative_ai_config_option_schema(
hass: HomeAssistant,

View File

@@ -6,7 +6,7 @@ DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
DEFAULT_CONVERSATION_NAME = "Google Conversation"
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
ATTR_MODEL = "model"
CONF_RECOMMENDED = "recommended"

View File

@@ -22,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_TOP_K,
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
@@ -115,7 +116,7 @@ async def test_form(hass: HomeAssistant) -> None:
{
"subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS,
"title": "Google Conversation",
"title": DEFAULT_CONVERSATION_NAME,
"unique_id": None,
}
]
@@ -425,7 +426,10 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
"""Test the reauth flow."""
hass.config.components.add("google_generative_ai_conversation")
mock_config_entry = MockConfigEntry(
domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, title="Gemini"
domain=DOMAIN,
state=config_entries.ConfigEntryState.LOADED,
title="Gemini",
version=2,
)
mock_config_entry.add_to_hass(hass)
mock_config_entry.async_start_reauth(hass)

View File

@@ -64,7 +64,7 @@ async def test_error_handling(
"hello",
None,
Context(),
agent_id="conversation.google_conversation",
agent_id="conversation.google_ai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result
@@ -82,7 +82,7 @@ async def test_function_call(
mock_send_message_stream: AsyncMock,
) -> None:
"""Test function calling."""
agent_id = "conversation.google_conversation"
agent_id = "conversation.google_ai_conversation"
context = Context()
messages = [
@@ -212,7 +212,7 @@ async def test_google_search_tool_is_sent(
mock_send_message_stream: AsyncMock,
) -> None:
"""Test if the Google Search tool is sent to the model."""
agent_id = "conversation.google_conversation"
agent_id = "conversation.google_ai_conversation"
context = Context()
messages = [
@@ -278,7 +278,7 @@ async def test_blocked_response(
mock_send_message_stream: AsyncMock,
) -> None:
"""Test blocked response."""
agent_id = "conversation.google_conversation"
agent_id = "conversation.google_ai_conversation"
context = Context()
messages = [
@@ -328,7 +328,7 @@ async def test_empty_response(
) -> None:
"""Test empty response."""
agent_id = "conversation.google_conversation"
agent_id = "conversation.google_ai_conversation"
context = Context()
messages = [
@@ -371,7 +371,7 @@ async def test_none_response(
mock_send_message_stream: AsyncMock,
) -> None:
"""Test None response."""
agent_id = "conversation.google_conversation"
agent_id = "conversation.google_ai_conversation"
context = Context()
messages = [
@@ -417,7 +417,7 @@ async def test_converse_error(
"hello",
None,
Context(),
agent_id="conversation.google_conversation",
agent_id="conversation.google_ai_conversation",
)
assert result.response.response_type == intent.IntentResponseType.ERROR, result
@@ -595,7 +595,7 @@ async def test_empty_content_in_chat_history(
mock_send_message_stream: AsyncMock,
) -> None:
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
agent_id = "conversation.google_conversation"
agent_id = "conversation.google_ai_conversation"
context = Context()
messages = [
@@ -650,7 +650,7 @@ async def test_history_always_user_first_turn(
) -> None:
"""Test that the user is always first in the chat history."""
agent_id = "conversation.google_conversation"
agent_id = "conversation.google_ai_conversation"
context = Context()
messages = [
@@ -676,7 +676,7 @@ async def test_history_always_user_first_turn(
mock_chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent(
agent_id="conversation.google_conversation",
agent_id="conversation.google_ai_conversation",
content="Garage door left open, do you want to close it?",
)
)

View File

@@ -10,7 +10,10 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_generative_ai_conversation import (
async_migrate_entry,
)
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_CONVERSATION_NAME,
DOMAIN,
)
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@@ -427,7 +430,7 @@ async def test_migration_from_v1_to_v2(
entity = entity_registry.async_get_or_create(
"conversation",
DOMAIN,
"mock_config_entry.entry_id",
mock_config_entry.entry_id,
config_entry=mock_config_entry,
device_id=device.id,
suggested_object_id="google_generative_ai_conversation",
@@ -445,7 +448,7 @@ async def test_migration_from_v1_to_v2(
subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.unique_id is None
assert subentry.title == "Google Conversation"
assert subentry.title == DEFAULT_CONVERSATION_NAME
assert subentry.subentry_type == "conversation"
assert subentry.data == OPTIONS

View File

@@ -122,7 +122,9 @@ async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Mock config entry setup."""
default_config = {tts.CONF_LANG: "en-US"}
config_entry = MockConfigEntry(domain=DOMAIN, data=default_config | config)
config_entry = MockConfigEntry(
domain=DOMAIN, data=default_config | config, version=2
)
client_mock = Mock()
client_mock.models.get = None