forked from home-assistant/core
Address comments
This commit is contained in:
@@ -232,14 +232,17 @@ async def async_migrate_entry(
|
||||
|
||||
# Migrate conversation entity to be linked to subentry
|
||||
ent_reg = er.async_get(hass)
|
||||
for entity_entry in er.async_entries_for_config_entry(ent_reg, entry.entry_id):
|
||||
if entity_entry.domain == Platform.CONVERSATION:
|
||||
ent_reg.async_update_entity(
|
||||
entity_entry.entity_id,
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
new_unique_id=subentry.subentry_id,
|
||||
)
|
||||
break
|
||||
conversation_entity = ent_reg.async_get_entity_id(
|
||||
"conversation",
|
||||
DOMAIN,
|
||||
entry.entry_id,
|
||||
)
|
||||
if conversation_entity is not None:
|
||||
ent_reg.async_update_entity(
|
||||
conversation_entity,
|
||||
config_subentry_id=subentry.subentry_id,
|
||||
new_unique_id=subentry.subentry_id,
|
||||
)
|
||||
|
||||
# Remove options from the main entry
|
||||
hass.config_entries.async_update_entry(
|
||||
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from google import genai
|
||||
from google.genai.errors import APIError, ClientError
|
||||
@@ -177,38 +177,36 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"""Flow for managing conversation subentries."""
|
||||
|
||||
last_rendered_recommended = False
|
||||
is_new: bool
|
||||
start_data: dict[str, Any]
|
||||
|
||||
@property
|
||||
def _genai_client(self) -> genai.Client:
|
||||
"""Return the Google Generative AI client."""
|
||||
return self._get_entry().runtime_data
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Add a subentry."""
|
||||
self.is_new = True
|
||||
self.start_data = RECOMMENDED_OPTIONS.copy()
|
||||
return await self.async_step_set_options()
|
||||
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Handle reconfiguration of a subentry."""
|
||||
self.is_new = False
|
||||
self.start_data = self._get_reconfigure_subentry().data.copy()
|
||||
return await self.async_step_set_options()
|
||||
@property
|
||||
def _is_new(self) -> bool:
|
||||
"""Return if this is a new subentry."""
|
||||
return self.source == "user"
|
||||
|
||||
async def async_step_set_options(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> SubentryFlowResult:
|
||||
"""Set conversation options."""
|
||||
options = self.start_data
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
if user_input is not None:
|
||||
if user_input is None:
|
||||
if self._is_new:
|
||||
options = RECOMMENDED_OPTIONS.copy()
|
||||
else:
|
||||
# If this is a reconfiguration, we need to copy the existing options
|
||||
# so that we can show the current values in the form.
|
||||
options = self._get_reconfigure_subentry().data.copy()
|
||||
|
||||
self.last_rendered_recommended = cast(
|
||||
bool, options.get(CONF_RECOMMENDED, False)
|
||||
)
|
||||
|
||||
else:
|
||||
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
|
||||
if not user_input.get(CONF_LLM_HASS_API):
|
||||
user_input.pop(CONF_LLM_HASS_API, None)
|
||||
@@ -217,7 +215,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
user_input.get(CONF_LLM_HASS_API)
|
||||
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
|
||||
):
|
||||
if self.is_new:
|
||||
if self._is_new:
|
||||
return self.async_create_entry(
|
||||
title=user_input.pop(CONF_NAME),
|
||||
data=user_input,
|
||||
@@ -234,16 +232,17 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
|
||||
|
||||
options = user_input
|
||||
else:
|
||||
self.last_rendered_recommended = options.get(CONF_RECOMMENDED, False)
|
||||
|
||||
schema = await google_generative_ai_config_option_schema(
|
||||
self.hass, self.is_new, options, self._genai_client
|
||||
self.hass, self._is_new, options, self._genai_client
|
||||
)
|
||||
return self.async_show_form(
|
||||
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
|
||||
)
|
||||
|
||||
async_step_reconfigure = async_step_set_options
|
||||
async_step_user = async_step_set_options
|
||||
|
||||
|
||||
async def google_generative_ai_config_option_schema(
|
||||
hass: HomeAssistant,
|
||||
|
@@ -6,7 +6,7 @@ DOMAIN = "google_generative_ai_conversation"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
CONF_PROMPT = "prompt"
|
||||
|
||||
DEFAULT_CONVERSATION_NAME = "Google Conversation"
|
||||
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
|
||||
|
||||
ATTR_MODEL = "model"
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
|
@@ -22,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
CONF_TOP_K,
|
||||
CONF_TOP_P,
|
||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CHAT_MODEL,
|
||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||
@@ -115,7 +116,7 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
{
|
||||
"subentry_type": "conversation",
|
||||
"data": RECOMMENDED_OPTIONS,
|
||||
"title": "Google Conversation",
|
||||
"title": DEFAULT_CONVERSATION_NAME,
|
||||
"unique_id": None,
|
||||
}
|
||||
]
|
||||
@@ -425,7 +426,10 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
|
||||
"""Test the reauth flow."""
|
||||
hass.config.components.add("google_generative_ai_conversation")
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, title="Gemini"
|
||||
domain=DOMAIN,
|
||||
state=config_entries.ConfigEntryState.LOADED,
|
||||
title="Gemini",
|
||||
version=2,
|
||||
)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
mock_config_entry.async_start_reauth(hass)
|
||||
|
@@ -64,7 +64,7 @@ async def test_error_handling(
|
||||
"hello",
|
||||
None,
|
||||
Context(),
|
||||
agent_id="conversation.google_conversation",
|
||||
agent_id="conversation.google_ai_conversation",
|
||||
)
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
assert result.response.error_code == "unknown", result
|
||||
@@ -82,7 +82,7 @@ async def test_function_call(
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test function calling."""
|
||||
agent_id = "conversation.google_conversation"
|
||||
agent_id = "conversation.google_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
@@ -212,7 +212,7 @@ async def test_google_search_tool_is_sent(
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test if the Google Search tool is sent to the model."""
|
||||
agent_id = "conversation.google_conversation"
|
||||
agent_id = "conversation.google_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
@@ -278,7 +278,7 @@ async def test_blocked_response(
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test blocked response."""
|
||||
agent_id = "conversation.google_conversation"
|
||||
agent_id = "conversation.google_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
@@ -328,7 +328,7 @@ async def test_empty_response(
|
||||
) -> None:
|
||||
"""Test empty response."""
|
||||
|
||||
agent_id = "conversation.google_conversation"
|
||||
agent_id = "conversation.google_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
@@ -371,7 +371,7 @@ async def test_none_response(
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test None response."""
|
||||
agent_id = "conversation.google_conversation"
|
||||
agent_id = "conversation.google_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
@@ -417,7 +417,7 @@ async def test_converse_error(
|
||||
"hello",
|
||||
None,
|
||||
Context(),
|
||||
agent_id="conversation.google_conversation",
|
||||
agent_id="conversation.google_ai_conversation",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||
@@ -595,7 +595,7 @@ async def test_empty_content_in_chat_history(
|
||||
mock_send_message_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
|
||||
agent_id = "conversation.google_conversation"
|
||||
agent_id = "conversation.google_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
@@ -650,7 +650,7 @@ async def test_history_always_user_first_turn(
|
||||
) -> None:
|
||||
"""Test that the user is always first in the chat history."""
|
||||
|
||||
agent_id = "conversation.google_conversation"
|
||||
agent_id = "conversation.google_ai_conversation"
|
||||
context = Context()
|
||||
|
||||
messages = [
|
||||
@@ -676,7 +676,7 @@ async def test_history_always_user_first_turn(
|
||||
|
||||
mock_chat_log.async_add_assistant_content_without_tools(
|
||||
conversation.AssistantContent(
|
||||
agent_id="conversation.google_conversation",
|
||||
agent_id="conversation.google_ai_conversation",
|
||||
content="Garage door left open, do you want to close it?",
|
||||
)
|
||||
)
|
||||
|
@@ -10,7 +10,10 @@ from syrupy.assertion import SnapshotAssertion
|
||||
from homeassistant.components.google_generative_ai_conversation import (
|
||||
async_migrate_entry,
|
||||
)
|
||||
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
|
||||
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -427,7 +430,7 @@ async def test_migration_from_v1_to_v2(
|
||||
entity = entity_registry.async_get_or_create(
|
||||
"conversation",
|
||||
DOMAIN,
|
||||
"mock_config_entry.entry_id",
|
||||
mock_config_entry.entry_id,
|
||||
config_entry=mock_config_entry,
|
||||
device_id=device.id,
|
||||
suggested_object_id="google_generative_ai_conversation",
|
||||
@@ -445,7 +448,7 @@ async def test_migration_from_v1_to_v2(
|
||||
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
assert subentry.unique_id is None
|
||||
assert subentry.title == "Google Conversation"
|
||||
assert subentry.title == DEFAULT_CONVERSATION_NAME
|
||||
assert subentry.subentry_type == "conversation"
|
||||
assert subentry.data == OPTIONS
|
||||
|
||||
|
@@ -122,7 +122,9 @@ async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||
async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||
"""Mock config entry setup."""
|
||||
default_config = {tts.CONF_LANG: "en-US"}
|
||||
config_entry = MockConfigEntry(domain=DOMAIN, data=default_config | config)
|
||||
config_entry = MockConfigEntry(
|
||||
domain=DOMAIN, data=default_config | config, version=2
|
||||
)
|
||||
|
||||
client_mock = Mock()
|
||||
client_mock.models.get = None
|
||||
|
Reference in New Issue
Block a user