Bump google-genai to 1.29.0 (#150225)
homeassistant/components/google_generative_ai_conversation/__init__.py

@@ -124,7 +124,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
                 f"Error generating content due to content violations, reason: {response.prompt_feedback.block_reason_message}"
             )
 
-        if not response.candidates[0].content.parts:
+        if (
+            not response.candidates
+            or not response.candidates[0].content
+            or not response.candidates[0].content.parts
+        ):
             raise HomeAssistantError("Unknown error generating content")
 
         return {"text": response.text}
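In google-genai 1.29 `candidates`, `Candidate.content`, and `Content.parts` are all typed as Optional, so the old single check could raise AttributeError before it ever evaluated. A minimal offline sketch of the widened guard (`has_parts` is a hypothetical helper, not part of the commit; it assumes only that google-genai is installed):

    from google.genai.types import Candidate, Content, GenerateContentResponse

    def has_parts(response: GenerateContentResponse) -> bool:
        # Each level may be None, so short-circuit top-down.
        return bool(
            response.candidates
            and response.candidates[0].content
            and response.candidates[0].content.parts
        )

    print(has_parts(GenerateContentResponse()))  # False, no AttributeError
    print(has_parts(GenerateContentResponse(candidates=[Candidate()])))  # False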
homeassistant/components/google_generative_ai_conversation/config_flow.py

@@ -377,7 +377,7 @@ async def google_generative_ai_config_option_schema(
                 value=api_model.name,
             )
             for api_model in sorted(
-                api_models, key=lambda x: x.name.lstrip("models/") or ""
+                api_models, key=lambda x: (x.name or "").lstrip("models/")
             )
             if (
                 api_model.name
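`Model.name` is `Optional[str]` in the new SDK, so the `or ""` coalescing moves before the `lstrip` call instead of after it. As an aside (behavior unchanged by this commit), `str.lstrip` strips a character set rather than a literal prefix, which `str.removeprefix` would:

    def sort_key(name: str | None) -> str:
        return (name or "").lstrip("models/")

    print(sort_key("models/gemini-2.5-flash"))  # 'gemini-2.5-flash' (happens to work)
    print(sort_key("models/mistral-x"))  # 'istral-x', the leading 'm' is in the strip set
    print("models/mistral-x".removeprefix("models/"))  # 'mistral-x', prefix-exact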
homeassistant/components/google_generative_ai_conversation/entity.py

@@ -4,7 +4,7 @@ from __future__ import annotations
 
 import asyncio
 import codecs
-from collections.abc import AsyncGenerator, Callable
+from collections.abc import AsyncGenerator, AsyncIterator, Callable
 from dataclasses import replace
 import mimetypes
 from pathlib import Path
@@ -15,6 +15,7 @@ from google.genai.errors import APIError, ClientError
 from google.genai.types import (
     AutomaticFunctionCallingConfig,
     Content,
+    ContentDict,
     File,
     FileState,
     FunctionDeclaration,
@@ -23,9 +24,11 @@ from google.genai.types import (
     GoogleSearch,
     HarmCategory,
     Part,
+    PartUnionDict,
     SafetySetting,
     Schema,
     Tool,
+    ToolListUnion,
 )
 import voluptuous as vol
 from voluptuous_openapi import convert
@@ -237,7 +240,7 @@ def _convert_content(
 
 
 async def _transform_stream(
-    result: AsyncGenerator[GenerateContentResponse],
+    result: AsyncIterator[GenerateContentResponse],
 ) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
     new_message = True
     try:
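Every async generator satisfies `AsyncIterator`, but not the reverse, so widening the parameter lets `_transform_stream` accept both the SDK's stream object and plain async generators. A small typing sketch (`FakeStream` and `agen` are hypothetical):

    from __future__ import annotations

    from collections.abc import AsyncGenerator, AsyncIterator

    class FakeStream:
        """A custom async iterator that is not an async generator."""

        def __aiter__(self) -> FakeStream:
            return self

        async def __anext__(self) -> int:
            raise StopAsyncIteration

    async def agen() -> AsyncGenerator[int]:
        yield 1

    stream: AsyncIterator[int] = FakeStream()  # accepted
    stream = agen()  # also accepted; async generators are AsyncIterators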
@@ -342,7 +345,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
         """Generate an answer for the chat log."""
         options = self.subentry.data
 
-        tools: list[Tool | Callable[..., Any]] | None = None
+        tools: ToolListUnion | None = None
         if chat_log.llm_api:
             tools = [
                 _format_tool(tool, chat_log.llm_api.custom_serializer)
@@ -373,7 +376,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
         else:
             raise HomeAssistantError("Invalid prompt content")
 
-        messages: list[Content] = []
+        messages: list[Content | ContentDict] = []
 
         # Google groups tool results, we do not. Group them before sending.
         tool_results: list[conversation.ToolResultContent] = []
@@ -400,7 +403,10 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
         # The SDK requires the first message to be a user message
         # This is not the case if user used `start_conversation`
         # Workaround from https://github.com/googleapis/python-genai/issues/529#issuecomment-2740964537
-        if messages and messages[0].role != "user":
+        if messages and (
+            (isinstance(messages[0], Content) and messages[0].role != "user")
+            or (isinstance(messages[0], dict) and messages[0]["role"] != "user")
+        ):
             messages.insert(
                 0,
                 Content(role="user", parts=[Part.from_text(text=" ")]),
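Because `messages` can now hold either `Content` models or plain `ContentDict` dicts, the role check has to branch on the runtime type. A hypothetical helper (not in the commit) showing the same pattern:

    from __future__ import annotations

    from google.genai.types import Content, ContentDict

    def first_role(messages: list[Content | ContentDict]) -> str | None:
        """Return the role of the first message, whichever form it takes."""
        if not messages:
            return None
        first = messages[0]
        return first["role"] if isinstance(first, dict) else first.role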
@@ -440,14 +446,14 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
             )
         user_message = chat_log.content[-1]
         assert isinstance(user_message, conversation.UserContent)
-        chat_request: str | list[Part] = user_message.content
+        chat_request: list[PartUnionDict] = [user_message.content]
         if user_message.attachments:
             files = await async_prepare_files_for_prompt(
                 self.hass,
                 self._genai_client,
                 [a.path for a in user_message.attachments],
             )
-            chat_request = [chat_request, *files]
+            chat_request = [*chat_request, *files]
 
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
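The request now starts life as a list rather than a bare string, so attachments must be spread into it; keeping the old `[chat_request, *files]` form would nest a list inside the list. A plain-Python illustration:

    old_request = "Describe this image from my doorbell camera"
    new_request = ["Describe this image from my doorbell camera"]
    files = ["file-1", "file-2"]  # stand-ins for google-genai File handles

    print([old_request, *files])   # flat: ['Describe ...', 'file-1', 'file-2']
    print([new_request, *files])   # nested: [['Describe ...'], 'file-1', 'file-2']
    print([*new_request, *files])  # flat again: ['Describe ...', 'file-1', 'file-2']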
@@ -464,15 +470,17 @@ class GoogleGenerativeAILLMBaseEntity(Entity):
                     error = ERROR_GETTING_RESPONSE
                 raise HomeAssistantError(error) from err
 
-            chat_request = _create_google_tool_response_parts(
-                [
-                    content
-                    async for content in chat_log.async_add_delta_content_stream(
-                        self.entity_id,
-                        _transform_stream(chat_response_generator),
-                    )
-                    if isinstance(content, conversation.ToolResultContent)
-                ]
+            chat_request = list(
+                _create_google_tool_response_parts(
+                    [
+                        content
+                        async for content in chat_log.async_add_delta_content_stream(
+                            self.entity_id,
+                            _transform_stream(chat_response_generator),
+                        )
+                        if isinstance(content, conversation.ToolResultContent)
+                    ]
+                )
             )
 
             if not chat_log.unresponded_tool_results:
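The extra `list()` looks redundant at runtime, but it is a typing fix: `_create_google_tool_response_parts` presumably returns `list[Part]` (its exact annotation is an assumption here) while `chat_request` is now `list[PartUnionDict]`, and `list` is invariant, so only a freshly built list type-checks. A minimal sketch:

    from google.genai.types import Part, PartUnionDict

    parts: list[Part] = [Part.from_text(text="tool result")]

    # Fresh list: the constructor's element type is inferred as PartUnionDict.
    chat_request: list[PartUnionDict] = list(parts)
    # chat_request = parts  # rejected by mypy: list[Part] is not list[PartUnionDict]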
@@ -559,13 +567,13 @@ async def async_prepare_files_for_prompt(
             await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)
 
             uploaded_file = await client.aio.files.get(
-                name=uploaded_file.name,
+                name=uploaded_file.name or "",
                 config={"http_options": {"timeout": TIMEOUT_MILLIS}},
             )
 
         if uploaded_file.state == FileState.FAILED:
             raise HomeAssistantError(
-                f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
+                f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message if uploaded_file.error else 'unknown'}"
             )
 
     prompt_parts = await hass.async_add_executor_job(upload_files)
homeassistant/components/google_generative_ai_conversation/manifest.json

@@ -8,5 +8,5 @@
   "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
   "integration_type": "service",
   "iot_class": "cloud_polling",
-  "requirements": ["google-genai==1.7.0"]
+  "requirements": ["google-genai==1.29.0"]
 }
homeassistant/components/google_generative_ai_conversation/tts.py

@@ -146,15 +146,41 @@ class GoogleGenerativeAITextToSpeechEntity(
                 )
             )
         )
 
+        def _extract_audio_parts(
+            response: types.GenerateContentResponse,
+        ) -> tuple[bytes, str]:
+            if (
+                not response.candidates
+                or not response.candidates[0].content
+                or not response.candidates[0].content.parts
+                or not response.candidates[0].content.parts[0].inline_data
+            ):
+                raise ValueError("No content returned from TTS generation")
+
+            data = response.candidates[0].content.parts[0].inline_data.data
+            mime_type = response.candidates[0].content.parts[0].inline_data.mime_type
+
+            if not isinstance(data, bytes):
+                raise TypeError(
+                    f"Expected bytes for audio data, got {type(data).__name__}"
+                )
+            if not isinstance(mime_type, str):
+                raise TypeError(
+                    f"Expected str for mime_type, got {type(mime_type).__name__}"
+                )
+
+            return data, mime_type
+
         try:
             response = await self._genai_client.aio.models.generate_content(
                 model=self.subentry.data.get(CONF_CHAT_MODEL, RECOMMENDED_TTS_MODEL),
                 contents=message,
                 config=config,
             )
-            data = response.candidates[0].content.parts[0].inline_data.data
-            mime_type = response.candidates[0].content.parts[0].inline_data.mime_type
-        except (APIError, ClientError, ValueError) as exc:
+            data, mime_type = _extract_audio_parts(response)
+        except (APIError, ClientError, ValueError, TypeError) as exc:
             LOGGER.error("Error during TTS: %s", exc, exc_info=True)
             raise HomeAssistantError(exc) from exc
         return "wav", convert_to_wav(data, mime_type)
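The new helper narrows the SDK's Optional and union-typed fields down to a concrete `(bytes, str)` pair before `convert_to_wav` runs, and `TypeError` joins the caught exceptions because the helper can now raise it. An offline sketch of its contract, building a minimal response by hand (the sample MIME type is an assumption about what the TTS model returns):

    from google.genai import types

    ok = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(
                    parts=[
                        types.Part(
                            inline_data=types.Blob(
                                data=b"\x00\x01", mime_type="audio/L16;rate=24000"
                            )
                        )
                    ]
                )
            )
        ]
    )
    # _extract_audio_parts(ok) -> (b'\x00\x01', 'audio/L16;rate=24000')
    # _extract_audio_parts(types.GenerateContentResponse()) raises ValueError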
requirements_all.txt (generated)

@@ -1057,7 +1057,7 @@ google-cloud-speech==2.31.1
 google-cloud-texttospeech==2.25.1
 
 # homeassistant.components.google_generative_ai_conversation
-google-genai==1.7.0
+google-genai==1.29.0
 
 # homeassistant.components.google_travel_time
 google-maps-routing==0.6.15
requirements_test_all.txt (generated)

@@ -924,7 +924,7 @@ google-cloud-speech==2.31.1
 google-cloud-texttospeech==2.25.1
 
 # homeassistant.components.google_generative_ai_conversation
-google-genai==1.7.0
+google-genai==1.29.0
 
 # homeassistant.components.google_travel_time
 google-maps-routing==0.6.15
tests/components/google_generative_ai_conversation/__init__.py

@@ -1,43 +1,16 @@
 """Tests for the Google Generative AI Conversation integration."""
 
-from unittest.mock import Mock
-
 from google.genai.errors import APIError, ClientError
-import httpx
 
 API_ERROR_500 = APIError(
     500,
-    Mock(
-        __class__=httpx.Response,
-        json=Mock(
-            return_value={
-                "message": "Internal Server Error",
-                "status": "internal-error",
-            }
-        ),
-    ),
+    {"message": "Internal Server Error", "status": "internal-error"},
 )
 CLIENT_ERROR_BAD_REQUEST = ClientError(
     400,
-    Mock(
-        __class__=httpx.Response,
-        json=Mock(
-            return_value={
-                "message": "Bad Request",
-                "status": "invalid-argument",
-            }
-        ),
-    ),
+    {"message": "Bad Request", "status": "invalid-argument"},
 )
 CLIENT_ERROR_API_KEY_INVALID = ClientError(
     400,
-    Mock(
-        __class__=httpx.Response,
-        json=Mock(
-            return_value={
-                "message": "'reason': API_KEY_INVALID",
-                "status": "unauthorized",
-            }
-        ),
-    ),
+    {"message": "'reason': API_KEY_INVALID", "status": "unauthorized"},
 )
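`APIError` and `ClientError` in google-genai 1.x accept the parsed response JSON directly, which is why the `httpx.Response` mocks could be dropped. A sketch (the `.code` and `.message` attribute names are assumed from the SDK's error model, not shown in this diff):

    from google.genai.errors import APIError

    err = APIError(500, {"message": "Internal Server Error", "status": "internal-error"})
    print(err.code)     # expected: 500
    print(err.message)  # expected: 'Internal Server Error'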
tests/components/google_generative_ai_conversation/snapshots/test_init.ambr

@@ -128,8 +128,14 @@
       dict({
         'contents': list([
           'Describe this image from my doorbell camera',
-          File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
-          File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.PROCESSING: 'PROCESSING'>, source=None, video_metadata=None, error=None),
+          File(
+            name='doorbell_snapshot.jpg',
+            state=<FileState.ACTIVE: 'ACTIVE'>
+          ),
+          File(
+            name='context.txt',
+            state=<FileState.PROCESSING: 'PROCESSING'>
+          ),
         ]),
         'model': 'models/gemini-2.5-flash',
       }),
@@ -145,8 +151,14 @@
       dict({
         'contents': list([
           'Describe this image from my doorbell camera',
-          File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
-          File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=<FileState.ACTIVE: 'ACTIVE'>, source=None, video_metadata=None, error=None),
+          File(
+            name='doorbell_snapshot.jpg',
+            state=<FileState.ACTIVE: 'ACTIVE'>
+          ),
+          File(
+            name='context.txt',
+            state=<FileState.ACTIVE: 'ACTIVE'>
+          ),
         ]),
         'model': 'models/gemini-2.5-flash',
       }),
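The snapshots shrank because the newer SDK's models only include explicitly set fields in their repr rather than every default-None field. Roughly (the exact line breaks above come from the snapshot serializer):

    from google.genai.types import File, FileState

    f = File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE)
    print(repr(f))
    # File(name='doorbell_snapshot.jpg', state=<FileState.ACTIVE: 'ACTIVE'>)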
tests/components/google_generative_ai_conversation/test_conversation.py

@@ -195,10 +195,13 @@ async def test_function_call(
                 "response": {
                     "result": "Test response",
                 },
+                "scheduling": None,
+                "will_continue": None,
             },
             "inline_data": None,
             "text": None,
             "thought": None,
+            "thought_signature": None,
             "video_metadata": None,
         }
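These keys appear because `FunctionResponse` gained `scheduling` and `will_continue` fields and `Part` gained `thought_signature` in the newer SDK, all defaulting to `None` (as the snapshot additions indicate). A quick check:

    from google.genai.types import FunctionResponse, Part

    fr = FunctionResponse(name="test_tool", response={"result": "Test response"})
    print(fr.scheduling, fr.will_continue)  # None None
    print(Part().thought_signature)  # None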
tests/components/google_generative_ai_conversation/test_tts.py

@@ -37,7 +37,7 @@ from tests.common import MockConfigEntry, async_mock_service
 from tests.components.tts.common import retrieve_media
 from tests.typing import ClientSessionGenerator
 
-API_ERROR_500 = APIError("test", response=MagicMock())
+API_ERROR_500 = APIError("test", response_json={})
 TEST_CHAT_MODEL = "models/some-tts-model"