forked from home-assistant/core
Address comments
This commit is contained in:
@@ -232,14 +232,17 @@ async def async_migrate_entry(
|
|||||||
|
|
||||||
# Migrate conversation entity to be linked to subentry
|
# Migrate conversation entity to be linked to subentry
|
||||||
ent_reg = er.async_get(hass)
|
ent_reg = er.async_get(hass)
|
||||||
for entity_entry in er.async_entries_for_config_entry(ent_reg, entry.entry_id):
|
conversation_entity = ent_reg.async_get_entity_id(
|
||||||
if entity_entry.domain == Platform.CONVERSATION:
|
"conversation",
|
||||||
|
DOMAIN,
|
||||||
|
entry.entry_id,
|
||||||
|
)
|
||||||
|
if conversation_entity is not None:
|
||||||
ent_reg.async_update_entity(
|
ent_reg.async_update_entity(
|
||||||
entity_entry.entity_id,
|
conversation_entity,
|
||||||
config_subentry_id=subentry.subentry_id,
|
config_subentry_id=subentry.subentry_id,
|
||||||
new_unique_id=subentry.subentry_id,
|
new_unique_id=subentry.subentry_id,
|
||||||
)
|
)
|
||||||
break
|
|
||||||
|
|
||||||
# Remove options from the main entry
|
# Remove options from the main entry
|
||||||
hass.config_entries.async_update_entry(
|
hass.config_entries.async_update_entry(
|
||||||
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai.errors import APIError, ClientError
|
from google.genai.errors import APIError, ClientError
|
||||||
@@ -177,38 +177,36 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
"""Flow for managing conversation subentries."""
|
"""Flow for managing conversation subentries."""
|
||||||
|
|
||||||
last_rendered_recommended = False
|
last_rendered_recommended = False
|
||||||
is_new: bool
|
|
||||||
start_data: dict[str, Any]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _genai_client(self) -> genai.Client:
|
def _genai_client(self) -> genai.Client:
|
||||||
"""Return the Google Generative AI client."""
|
"""Return the Google Generative AI client."""
|
||||||
return self._get_entry().runtime_data
|
return self._get_entry().runtime_data
|
||||||
|
|
||||||
async def async_step_user(
|
@property
|
||||||
self, user_input: dict[str, Any] | None = None
|
def _is_new(self) -> bool:
|
||||||
) -> SubentryFlowResult:
|
"""Return if this is a new subentry."""
|
||||||
"""Add a subentry."""
|
return self.source == "user"
|
||||||
self.is_new = True
|
|
||||||
self.start_data = RECOMMENDED_OPTIONS.copy()
|
|
||||||
return await self.async_step_set_options()
|
|
||||||
|
|
||||||
async def async_step_reconfigure(
|
|
||||||
self, user_input: dict[str, Any] | None = None
|
|
||||||
) -> SubentryFlowResult:
|
|
||||||
"""Handle reconfiguration of a subentry."""
|
|
||||||
self.is_new = False
|
|
||||||
self.start_data = self._get_reconfigure_subentry().data.copy()
|
|
||||||
return await self.async_step_set_options()
|
|
||||||
|
|
||||||
async def async_step_set_options(
|
async def async_step_set_options(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
) -> SubentryFlowResult:
|
) -> SubentryFlowResult:
|
||||||
"""Set conversation options."""
|
"""Set conversation options."""
|
||||||
options = self.start_data
|
|
||||||
errors: dict[str, str] = {}
|
errors: dict[str, str] = {}
|
||||||
|
|
||||||
if user_input is not None:
|
if user_input is None:
|
||||||
|
if self._is_new:
|
||||||
|
options = RECOMMENDED_OPTIONS.copy()
|
||||||
|
else:
|
||||||
|
# If this is a reconfiguration, we need to copy the existing options
|
||||||
|
# so that we can show the current values in the form.
|
||||||
|
options = self._get_reconfigure_subentry().data.copy()
|
||||||
|
|
||||||
|
self.last_rendered_recommended = cast(
|
||||||
|
bool, options.get(CONF_RECOMMENDED, False)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
|
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
|
||||||
if not user_input.get(CONF_LLM_HASS_API):
|
if not user_input.get(CONF_LLM_HASS_API):
|
||||||
user_input.pop(CONF_LLM_HASS_API, None)
|
user_input.pop(CONF_LLM_HASS_API, None)
|
||||||
@@ -217,7 +215,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
user_input.get(CONF_LLM_HASS_API)
|
user_input.get(CONF_LLM_HASS_API)
|
||||||
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
|
and user_input.get(CONF_USE_GOOGLE_SEARCH_TOOL, False) is True
|
||||||
):
|
):
|
||||||
if self.is_new:
|
if self._is_new:
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=user_input.pop(CONF_NAME),
|
title=user_input.pop(CONF_NAME),
|
||||||
data=user_input,
|
data=user_input,
|
||||||
@@ -234,16 +232,17 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
|||||||
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
|
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
|
||||||
|
|
||||||
options = user_input
|
options = user_input
|
||||||
else:
|
|
||||||
self.last_rendered_recommended = options.get(CONF_RECOMMENDED, False)
|
|
||||||
|
|
||||||
schema = await google_generative_ai_config_option_schema(
|
schema = await google_generative_ai_config_option_schema(
|
||||||
self.hass, self.is_new, options, self._genai_client
|
self.hass, self._is_new, options, self._genai_client
|
||||||
)
|
)
|
||||||
return self.async_show_form(
|
return self.async_show_form(
|
||||||
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
|
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async_step_reconfigure = async_step_set_options
|
||||||
|
async_step_user = async_step_set_options
|
||||||
|
|
||||||
|
|
||||||
async def google_generative_ai_config_option_schema(
|
async def google_generative_ai_config_option_schema(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
|
@@ -6,7 +6,7 @@ DOMAIN = "google_generative_ai_conversation"
|
|||||||
LOGGER = logging.getLogger(__package__)
|
LOGGER = logging.getLogger(__package__)
|
||||||
CONF_PROMPT = "prompt"
|
CONF_PROMPT = "prompt"
|
||||||
|
|
||||||
DEFAULT_CONVERSATION_NAME = "Google Conversation"
|
DEFAULT_CONVERSATION_NAME = "Google AI Conversation"
|
||||||
|
|
||||||
ATTR_MODEL = "model"
|
ATTR_MODEL = "model"
|
||||||
CONF_RECOMMENDED = "recommended"
|
CONF_RECOMMENDED = "recommended"
|
||||||
|
@@ -22,6 +22,7 @@ from homeassistant.components.google_generative_ai_conversation.const import (
|
|||||||
CONF_TOP_K,
|
CONF_TOP_K,
|
||||||
CONF_TOP_P,
|
CONF_TOP_P,
|
||||||
CONF_USE_GOOGLE_SEARCH_TOOL,
|
CONF_USE_GOOGLE_SEARCH_TOOL,
|
||||||
|
DEFAULT_CONVERSATION_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
RECOMMENDED_CHAT_MODEL,
|
RECOMMENDED_CHAT_MODEL,
|
||||||
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
RECOMMENDED_HARM_BLOCK_THRESHOLD,
|
||||||
@@ -115,7 +116,7 @@ async def test_form(hass: HomeAssistant) -> None:
|
|||||||
{
|
{
|
||||||
"subentry_type": "conversation",
|
"subentry_type": "conversation",
|
||||||
"data": RECOMMENDED_OPTIONS,
|
"data": RECOMMENDED_OPTIONS,
|
||||||
"title": "Google Conversation",
|
"title": DEFAULT_CONVERSATION_NAME,
|
||||||
"unique_id": None,
|
"unique_id": None,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -425,7 +426,10 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
|
|||||||
"""Test the reauth flow."""
|
"""Test the reauth flow."""
|
||||||
hass.config.components.add("google_generative_ai_conversation")
|
hass.config.components.add("google_generative_ai_conversation")
|
||||||
mock_config_entry = MockConfigEntry(
|
mock_config_entry = MockConfigEntry(
|
||||||
domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, title="Gemini"
|
domain=DOMAIN,
|
||||||
|
state=config_entries.ConfigEntryState.LOADED,
|
||||||
|
title="Gemini",
|
||||||
|
version=2,
|
||||||
)
|
)
|
||||||
mock_config_entry.add_to_hass(hass)
|
mock_config_entry.add_to_hass(hass)
|
||||||
mock_config_entry.async_start_reauth(hass)
|
mock_config_entry.async_start_reauth(hass)
|
||||||
|
@@ -64,7 +64,7 @@ async def test_error_handling(
|
|||||||
"hello",
|
"hello",
|
||||||
None,
|
None,
|
||||||
Context(),
|
Context(),
|
||||||
agent_id="conversation.google_conversation",
|
agent_id="conversation.google_ai_conversation",
|
||||||
)
|
)
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
assert result.response.error_code == "unknown", result
|
assert result.response.error_code == "unknown", result
|
||||||
@@ -82,7 +82,7 @@ async def test_function_call(
|
|||||||
mock_send_message_stream: AsyncMock,
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test function calling."""
|
"""Test function calling."""
|
||||||
agent_id = "conversation.google_conversation"
|
agent_id = "conversation.google_ai_conversation"
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -212,7 +212,7 @@ async def test_google_search_tool_is_sent(
|
|||||||
mock_send_message_stream: AsyncMock,
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test if the Google Search tool is sent to the model."""
|
"""Test if the Google Search tool is sent to the model."""
|
||||||
agent_id = "conversation.google_conversation"
|
agent_id = "conversation.google_ai_conversation"
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -278,7 +278,7 @@ async def test_blocked_response(
|
|||||||
mock_send_message_stream: AsyncMock,
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test blocked response."""
|
"""Test blocked response."""
|
||||||
agent_id = "conversation.google_conversation"
|
agent_id = "conversation.google_ai_conversation"
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -328,7 +328,7 @@ async def test_empty_response(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test empty response."""
|
"""Test empty response."""
|
||||||
|
|
||||||
agent_id = "conversation.google_conversation"
|
agent_id = "conversation.google_ai_conversation"
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -371,7 +371,7 @@ async def test_none_response(
|
|||||||
mock_send_message_stream: AsyncMock,
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test None response."""
|
"""Test None response."""
|
||||||
agent_id = "conversation.google_conversation"
|
agent_id = "conversation.google_ai_conversation"
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -417,7 +417,7 @@ async def test_converse_error(
|
|||||||
"hello",
|
"hello",
|
||||||
None,
|
None,
|
||||||
Context(),
|
Context(),
|
||||||
agent_id="conversation.google_conversation",
|
agent_id="conversation.google_ai_conversation",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
assert result.response.response_type == intent.IntentResponseType.ERROR, result
|
||||||
@@ -595,7 +595,7 @@ async def test_empty_content_in_chat_history(
|
|||||||
mock_send_message_stream: AsyncMock,
|
mock_send_message_stream: AsyncMock,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
|
"""Tests that in case of an empty entry in the chat history the google API will receive an injected space sign instead."""
|
||||||
agent_id = "conversation.google_conversation"
|
agent_id = "conversation.google_ai_conversation"
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -650,7 +650,7 @@ async def test_history_always_user_first_turn(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the user is always first in the chat history."""
|
"""Test that the user is always first in the chat history."""
|
||||||
|
|
||||||
agent_id = "conversation.google_conversation"
|
agent_id = "conversation.google_ai_conversation"
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -676,7 +676,7 @@ async def test_history_always_user_first_turn(
|
|||||||
|
|
||||||
mock_chat_log.async_add_assistant_content_without_tools(
|
mock_chat_log.async_add_assistant_content_without_tools(
|
||||||
conversation.AssistantContent(
|
conversation.AssistantContent(
|
||||||
agent_id="conversation.google_conversation",
|
agent_id="conversation.google_ai_conversation",
|
||||||
content="Garage door left open, do you want to close it?",
|
content="Garage door left open, do you want to close it?",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@@ -10,7 +10,10 @@ from syrupy.assertion import SnapshotAssertion
|
|||||||
from homeassistant.components.google_generative_ai_conversation import (
|
from homeassistant.components.google_generative_ai_conversation import (
|
||||||
async_migrate_entry,
|
async_migrate_entry,
|
||||||
)
|
)
|
||||||
from homeassistant.components.google_generative_ai_conversation.const import DOMAIN
|
from homeassistant.components.google_generative_ai_conversation.const import (
|
||||||
|
DEFAULT_CONVERSATION_NAME,
|
||||||
|
DOMAIN,
|
||||||
|
)
|
||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
@@ -427,7 +430,7 @@ async def test_migration_from_v1_to_v2(
|
|||||||
entity = entity_registry.async_get_or_create(
|
entity = entity_registry.async_get_or_create(
|
||||||
"conversation",
|
"conversation",
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
"mock_config_entry.entry_id",
|
mock_config_entry.entry_id,
|
||||||
config_entry=mock_config_entry,
|
config_entry=mock_config_entry,
|
||||||
device_id=device.id,
|
device_id=device.id,
|
||||||
suggested_object_id="google_generative_ai_conversation",
|
suggested_object_id="google_generative_ai_conversation",
|
||||||
@@ -445,7 +448,7 @@ async def test_migration_from_v1_to_v2(
|
|||||||
|
|
||||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||||
assert subentry.unique_id is None
|
assert subentry.unique_id is None
|
||||||
assert subentry.title == "Google Conversation"
|
assert subentry.title == DEFAULT_CONVERSATION_NAME
|
||||||
assert subentry.subentry_type == "conversation"
|
assert subentry.subentry_type == "conversation"
|
||||||
assert subentry.data == OPTIONS
|
assert subentry.data == OPTIONS
|
||||||
|
|
||||||
|
@@ -122,7 +122,9 @@ async def mock_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
|
|||||||
async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
|
async def mock_config_entry_setup(hass: HomeAssistant, config: dict[str, Any]) -> None:
|
||||||
"""Mock config entry setup."""
|
"""Mock config entry setup."""
|
||||||
default_config = {tts.CONF_LANG: "en-US"}
|
default_config = {tts.CONF_LANG: "en-US"}
|
||||||
config_entry = MockConfigEntry(domain=DOMAIN, data=default_config | config)
|
config_entry = MockConfigEntry(
|
||||||
|
domain=DOMAIN, data=default_config | config, version=2
|
||||||
|
)
|
||||||
|
|
||||||
client_mock = Mock()
|
client_mock = Mock()
|
||||||
client_mock.models.get = None
|
client_mock.models.get = None
|
||||||
|
Reference in New Issue
Block a user