AI task generate_text -> generate_data (#147370)

This commit is contained in:
Paulus Schoutsen
2025-06-24 07:12:29 -04:00
committed by GitHub
parent 38c7eaf70a
commit 63ac14a19b
14 changed files with 104 additions and 100 deletions

View File

@@ -24,20 +24,20 @@ from .const import (
DATA_COMPONENT,
DATA_PREFERENCES,
DOMAIN,
SERVICE_GENERATE_TEXT,
SERVICE_GENERATE_DATA,
AITaskEntityFeature,
)
from .entity import AITaskEntity
from .http import async_setup as async_setup_http
from .task import GenTextTask, GenTextTaskResult, async_generate_text
from .task import GenDataTask, GenDataTaskResult, async_generate_data
__all__ = [
"DOMAIN",
"AITaskEntity",
"AITaskEntityFeature",
"GenTextTask",
"GenTextTaskResult",
"async_generate_text",
"GenDataTask",
"GenDataTaskResult",
"async_generate_data",
"async_setup",
"async_setup_entry",
"async_unload_entry",
@@ -57,8 +57,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async_setup_http(hass)
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_TEXT,
async_service_generate_text,
SERVICE_GENERATE_DATA,
async_service_generate_data,
schema=vol.Schema(
{
vol.Required(ATTR_TASK_NAME): cv.string,
@@ -82,18 +82,18 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
async def async_service_generate_text(call: ServiceCall) -> ServiceResponse:
async def async_service_generate_data(call: ServiceCall) -> ServiceResponse:
"""Run the run task service."""
result = await async_generate_text(hass=call.hass, **call.data)
return result.as_dict() # type: ignore[return-value]
result = await async_generate_data(hass=call.hass, **call.data)
return result.as_dict()
class AITaskPreferences:
"""AI Task preferences."""
KEYS = ("gen_text_entity_id",)
KEYS = ("gen_data_entity_id",)
gen_text_entity_id: str | None = None
gen_data_entity_id: str | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the preferences."""
@@ -113,11 +113,11 @@ class AITaskPreferences:
def async_set_preferences(
self,
*,
gen_text_entity_id: str | None | UndefinedType = UNDEFINED,
gen_data_entity_id: str | None | UndefinedType = UNDEFINED,
) -> None:
"""Set the preferences."""
changed = False
for key, value in (("gen_text_entity_id", gen_text_entity_id),):
for key, value in (("gen_data_entity_id", gen_data_entity_id),):
if value is not UNDEFINED:
if getattr(self, key) != value:
setattr(self, key, value)

View File

@@ -17,7 +17,7 @@ DOMAIN = "ai_task"
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
SERVICE_GENERATE_TEXT = "generate_text"
SERVICE_GENERATE_DATA = "generate_data"
ATTR_INSTRUCTIONS: Final = "instructions"
ATTR_TASK_NAME: Final = "task_name"
@@ -30,5 +30,5 @@ DEFAULT_SYSTEM_PROMPT = (
class AITaskEntityFeature(IntFlag):
"""Supported features of the AI task entity."""
GENERATE_TEXT = 1
"""Generate text based on instructions."""
GENERATE_DATA = 1
"""Generate data based on instructions."""

View File

@@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import dt as dt_util
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
from .task import GenTextTask, GenTextTaskResult
from .task import GenDataTask, GenDataTaskResult
class AITaskEntity(RestoreEntity):
@@ -56,7 +56,7 @@ class AITaskEntity(RestoreEntity):
@contextlib.asynccontextmanager
async def _async_get_ai_task_chat_log(
self,
task: GenTextTask,
task: GenDataTask,
) -> AsyncGenerator[ChatLog]:
"""Context manager used to manage the ChatLog used during an AI Task."""
# pylint: disable-next=contextmanager-generator-missing-cleanup
@@ -84,20 +84,20 @@ class AITaskEntity(RestoreEntity):
yield chat_log
@final
async def internal_async_generate_text(
async def internal_async_generate_data(
self,
task: GenTextTask,
) -> GenTextTaskResult:
"""Run a gen text task."""
task: GenDataTask,
) -> GenDataTaskResult:
"""Run a gen data task."""
self.__last_activity = dt_util.utcnow().isoformat()
self.async_write_ha_state()
async with self._async_get_ai_task_chat_log(task) as chat_log:
return await self._async_generate_text(task, chat_log)
return await self._async_generate_data(task, chat_log)
async def _async_generate_text(
async def _async_generate_data(
self,
task: GenTextTask,
task: GenDataTask,
chat_log: ChatLog,
) -> GenTextTaskResult:
"""Handle a gen text task."""
) -> GenDataTaskResult:
"""Handle a gen data task."""
raise NotImplementedError

View File

@@ -36,7 +36,7 @@ def websocket_get_preferences(
@websocket_api.websocket_command(
{
vol.Required("type"): "ai_task/preferences/set",
vol.Optional("gen_text_entity_id"): vol.Any(str, None),
vol.Optional("gen_data_entity_id"): vol.Any(str, None),
}
)
@websocket_api.require_admin

View File

@@ -1,6 +1,6 @@
{
"services": {
"generate_text": {
"generate_data": {
"service": "mdi:file-star-four-points-outline"
}
}

View File

@@ -1,4 +1,4 @@
generate_text:
generate_data:
fields:
task_name:
example: "home summary"
@@ -16,4 +16,4 @@ generate_text:
entity:
domain: ai_task
supported_features:
- ai_task.AITaskEntityFeature.GENERATE_TEXT
- ai_task.AITaskEntityFeature.GENERATE_DATA

View File

@@ -1,8 +1,8 @@
{
"services": {
"generate_text": {
"name": "Generate text",
"description": "Use AI to run a task that generates text.",
"generate_data": {
"name": "Generate data",
"description": "Uses AI to run a task that generates data.",
"fields": {
"task_name": {
"name": "Task name",

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
@@ -10,16 +11,16 @@ from homeassistant.exceptions import HomeAssistantError
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
async def async_generate_text(
async def async_generate_data(
hass: HomeAssistant,
*,
task_name: str,
entity_id: str | None = None,
instructions: str,
) -> GenTextTaskResult:
) -> GenDataTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
if entity_id is None:
raise HomeAssistantError("No entity_id provided and no preferred entity set")
@@ -28,13 +29,13 @@ async def async_generate_text(
if entity is None:
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features:
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
raise HomeAssistantError(
f"AI Task entity {entity_id} does not support generating text"
f"AI Task entity {entity_id} does not support generating data"
)
return await entity.internal_async_generate_text(
GenTextTask(
return await entity.internal_async_generate_data(
GenDataTask(
name=task_name,
instructions=instructions,
)
@@ -42,8 +43,8 @@ async def async_generate_text(
@dataclass(slots=True)
class GenTextTask:
"""Gen text task to be processed."""
class GenDataTask:
"""Gen data task to be processed."""
name: str
"""Name of the task."""
@@ -53,22 +54,22 @@ class GenTextTask:
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenTextTask {self.name}: {id(self)}>"
return f"<GenDataTask {self.name}: {id(self)}>"
@dataclass(slots=True)
class GenTextTaskResult:
"""Result of gen text task."""
class GenDataTaskResult:
"""Result of gen data task."""
conversation_id: str
"""Unique identifier for the conversation."""
text: str
"""Generated text."""
data: Any
"""Data generated by the task."""
def as_dict(self) -> dict[str, str]:
def as_dict(self) -> dict[str, Any]:
"""Return result as a dict."""
return {
"conversation_id": self.conversation_id,
"text": self.text,
"data": self.data,
}