Address comments

This commit is contained in:
Paulus Schoutsen
2025-06-22 12:01:08 +00:00
parent 01459d0f35
commit 95a79d5b0c
7 changed files with 60 additions and 49 deletions

View File

@@ -232,14 +232,17 @@ async def async_migrate_entry(
# Migrate conversation entity to be linked to subentry # Migrate conversation entity to be linked to subentry
ent_reg = er.async_get(hass) ent_reg = er.async_get(hass)
for entity_entry in er.async_entries_for_config_entry(ent_reg, entry.entry_id): conversation_entity = ent_reg.async_get_entity_id(
if entity_entry.domain == Platform.CONVERSATION: "conversation",
DOMAIN,
entry.entry_id,
)
if conversation_entity is not None:
ent_reg.async_update_entity( ent_reg.async_update_entity(
entity_entry.entity_id, conversation_entity,
config_subentry_id=subentry.subentry_id, config_subentry_id=subentry.subentry_id,
new_unique_id=subentry.subentry_id, new_unique_id=subentry.subentry_id,
) )
break
# Remove options from the main entry # Remove options from the main entry
hass.config_entries.async_update_entry( hass.config_entries.async_update_entry(

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
from typing import Any from typing import Any, cast
from google import genai from google import genai
from google.genai.errors import APIError, ClientError from google.genai.errors import APIError, ClientError
@@ -177,38 +177,36 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"""Flow for managing conversation subentries.""" """Flow for managing conversation subentries."""
last_rendered_recommended = False last_rendered_recommended = False
is_new: bool
start_data: dict[str, Any]
@property @property
def _genai_client(self) -> genai.Client: def _genai_client(self) -> genai.Client:
"""Return the Google Generative AI client.""" """Return the Google Generative AI client."""
return self._get_entry().runtime_data return self._get_entry().runtime_data
async def async_step_user( @property
self, user_input: dict[str, Any] | None = None def _is_new(self) -> bool:
) -> SubentryFlowResult: """Return if this is a new subentry."""
"""Add a subentry.""" return self.source == "user"
self.is_new = True
self.start_data = RECOMMENDED_OPTIONS.copy()
return await self.async_step_set_options()
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult:
"""Handle reconfiguration of a subentry."""
self.is_new = False
self.start_data = self._get_reconfigure_subentry().data.copy()
return await self.async_step_set_options()
async def async_step_set_options( async def async_step_set_options(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> SubentryFlowResult: ) -> SubentryFlowResult:
"""Set conversation options.""" """Set conversation options."""
options = self.start_data
errors: dict[str, str] = {} errors: dict[str, str] = {}
if user_input is not None: if user_input is None:
if self._is_new:
options = RECOMMENDED_OPTIONS.copy()
else:
# If this is a reconfiguration, we need to copy the existing options
# so that we can show the current values in the form.
options = self._get_reconfigure_subentry().data.copy()
self.last_rendered_recommended = cast(
bool, options.get(CONF_RECOMMENDED, False)
)
else:
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended: if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if not user_input.get(CONF_LLM_HASS_API): if not user_input.get(CONF_LLM_HASS_API):
user_input.pop(CONF_LLM_HASS_API, None) user_input.pop(CONF_LLM_HASS_API, None)
@@ -217,7 +215,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
user_input.get(CONF_LLM_HASS_API) user_input.get(CONF_LLM_HASS_API)
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
): ):
if self.is_new: if self._is_new:
return self.async_create_entry( return self.async_create_entry(
title=user_input.pop(CONF_NAME), title=user_input.pop(CONF_NAME),
data=user_input, data=user_input,
@@ -234,16 +232,17 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
self.last_rendered_recommended = user_input[CONF_RECOMMENDED] self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
options = user_input options = user_input
else:
self.last_rendered_recommended = options.get(CONF_RECOMMENDED, False)
schema = await google_generative_ai_config_option_schema( schema = await google_generative_ai_config_option_schema(
self.hass, self.is_new, options, self._genai_client self.hass, self._is_new, options, self._genai_client
) )
return self.async_show_form( return self.async_show_form(
step_id="set_options", data_schema=vol.Schema(schema), errors=errors step_id="set_options", data_schema=vol.Schema(schema), errors=errors
) )
async_step_reconfigure = async_step_set_options
async_step_user = async_step_set_options
async def google_generative_ai_config_option_schema( async def google_generative_ai_config_option_schema(
hass: HomeAssistant, hass: HomeAssistant,

View File

@@ -6,7 +6,7 @@ DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__) LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
DEFAULT_CONVERSATION_NAME = "Google Conversation" DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
ATTR_MODEL = "model" ATTR_MODEL = "model"
CONF_RECOMMENDED = "recommended" CONF_RECOMMENDED = "recommended"

View File

@@ -22,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
CONF_TOP_K, CONF_TOP_K,
CONF_TOP_P, CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL, CONF_USE_GOOGLE_SEARCH_TOOL,
DEFAULT_CONVERSATION_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL, RECOMMENDED_CHAT_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD, RECOMMENDED_HARM_BLOCK_THRESHOLD,
@@ -115,7 +116,7 @@ async def test_form(hass: HomeAssistant) -> None:
{ {
"subentry_type": "conversation", "subentry_type": "conversation",
"data": RECOMMENDED_OPTIONS, "data": RECOMMENDED_OPTIONS,
"title": "Google Conversation", "title": DEFAULT_CONVERSATION_NAME,
"unique_id": None, "unique_id": None,
} }
] ]
@@ -425,7 +426,10 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
"""Test the reauth flow.""" """Test the reauth flow."""
hass.config.components.add("google_generative_ai_conversation") hass.config.components.add("google_generative_ai_conversation")
mock_config_entry = MockConfigEntry( mock_config_entry = MockConfigEntry(
domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, title="Gemini" domain=DOMAIN,
state=config_entries.ConfigEntryState.LOADED,
title="Gemini",
version=2,
) )
mock_config_entry.add_to_hass(hass) mock_config_entry.add_to_hass(hass)
mock_config_entry.async_start_reauth(hass) mock_config_entry.async_start_reauth(hass)

View File

@@ -64,7 +64,7 @@ async def test_error_handling(
"hello", "hello",
None, None,
Context(), Context(),
agent_id="conversation.google_conversation", agent_id="conversation.google_ai_conversation",
) )
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
assert result.response.error_code == "unknown", result assert result.response.error_code == "unknown", result
@@ -82,7 +82,7 @@ async def test_function_call(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Test function calling.""" """Test function calling."""
agent_id = "conversation.google_conversation" agent_id = "conversation.google_ai_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -212,7 +212,7 @@ async def test_google_search_tool_is_sent(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Test if the Google Search tool is sent to the model.""" """Test if the Google Search tool is sent to the model."""
agent_id = "conversation.google_conversation" agent_id = "conversation.google_ai_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -278,7 +278,7 @@ async def test_blocked_response(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Test blocked response.""" """Test blocked response."""
agent_id = "conversation.google_conversation" agent_id = "conversation.google_ai_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -328,7 +328,7 @@ async def test_empty_response(
) -> None: ) -> None:
"""Test empty response.""" """Test empty response."""
agent_id = "conversation.google_conversation" agent_id = "conversation.google_ai_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -371,7 +371,7 @@ async def test_none_response(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Test None response.""" """Test None response."""
agent_id = "conversation.google_conversation" agent_id = "conversation.google_ai_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -417,7 +417,7 @@ async def test_converse_error(
"hello", "hello",
None, None,
Context(), Context(),
agent_id="conversation.google_conversation", agent_id="conversation.google_ai_conversation",
) )
assert result.response.response_type == intent.IntentResponseType.ERROR, result assert result.response.response_type == intent.IntentResponseType.ERROR, result
@@ -595,7 +595,7 @@ async def test_empty_content_in_chat_history(
mock_send_message_stream: AsyncMock, mock_send_message_stream: AsyncMock,
) -> None: ) -> None:
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead.""" """Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
agent_id = "conversation.google_conversation" agent_id = "conversation.google_ai_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -650,7 +650,7 @@ async def test_history_always_user_first_turn(
) -> None: ) -> None:
"""Test that the user is always first in the chat history.""" """Test that the user is always first in the chat history."""
agent_id = "conversation.google_conversation" agent_id = "conversation.google_ai_conversation"
context = Context() context = Context()
messages = [ messages = [
@@ -676,7 +676,7 @@ async def test_history_always_user_first_turn(
mock_chat_log.async_add_assistant_content_without_tools( mock_chat_log.async_add_assistant_content_without_tools(
conversation.AssistantContent( conversation.AssistantContent(
agent_id="conversation.google_conversation", agent_id="conversation.google_ai_conversation",
content="Garage door left open, do you want to close it?", content="Garage door left open, do you want to close it?",
) )
) )

View File

@@ -10,7 +10,10 @@ from syrupy.assertion import SnapshotAssertion
from homeassistant.components.google_generative_ai_conversation import ( from homeassistant.components.google_generative_ai_conversation import (
async_migrate_entry, async_migrate_entry,
) )
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN from homeassistant.components.google_generative_ai_conversation.const import (
DEFAULT_CONVERSATION_NAME,
DOMAIN,
)
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
@@ -427,7 +430,7 @@ async def test_migration_from_v1_to_v2(
entity = entity_registry.async_get_or_create( entity = entity_registry.async_get_or_create(
"conversation", "conversation",
DOMAIN, DOMAIN,
"mock_config_entry.entry_id", mock_config_entry.entry_id,
config_entry=mock_config_entry, config_entry=mock_config_entry,
device_id=device.id, device_id=device.id,
suggested_object_id="google_generative_ai_conversation", suggested_object_id="google_generative_ai_conversation",
@@ -445,7 +448,7 @@ async def test_migration_from_v1_to_v2(
subentry = next(iter(mock_config_entry.subentries.values())) subentry = next(iter(mock_config_entry.subentries.values()))
assert subentry.unique_id is None assert subentry.unique_id is None
assert subentry.title == "Google Conversation" assert subentry.title == DEFAULT_CONVERSATION_NAME
assert subentry.subentry_type == "conversation" assert subentry.subentry_type == "conversation"
assert subentry.data == OPTIONS assert subentry.data == OPTIONS

View File

@@ -122,7 +122,9 @@ async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None: async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
"""Mock config entry setup.""" """Mock config entry setup."""
default_config = {tts.CONF_LANG: "en-US"} default_config = {tts.CONF_LANG: "en-US"}
config_entry = MockConfigEntry(domain=DOMAIN, data=default_config | config) config_entry = MockConfigEntry(
domain=DOMAIN, data=default_config | config, version=2
)
client_mock = Mock() client_mock = Mock()
client_mock.models.get = None client_mock.models.get = None