mirror of
https://github.com/home-assistant/core.git
synced 2025-08-01 03:35:09 +02:00
AI task generate_text -> generate_data (#147370)
This commit is contained in:
@@ -24,20 +24,20 @@ from .const import (
|
||||
DATA_COMPONENT,
|
||||
DATA_PREFERENCES,
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_TEXT,
|
||||
SERVICE_GENERATE_DATA,
|
||||
AITaskEntityFeature,
|
||||
)
|
||||
from .entity import AITaskEntity
|
||||
from .http import async_setup as async_setup_http
|
||||
from .task import GenTextTask, GenTextTaskResult, async_generate_text
|
||||
from .task import GenDataTask, GenDataTaskResult, async_generate_data
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"AITaskEntity",
|
||||
"AITaskEntityFeature",
|
||||
"GenTextTask",
|
||||
"GenTextTaskResult",
|
||||
"async_generate_text",
|
||||
"GenDataTask",
|
||||
"GenDataTaskResult",
|
||||
"async_generate_data",
|
||||
"async_setup",
|
||||
"async_setup_entry",
|
||||
"async_unload_entry",
|
||||
@@ -57,8 +57,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
async_setup_http(hass)
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
SERVICE_GENERATE_TEXT,
|
||||
async_service_generate_text,
|
||||
SERVICE_GENERATE_DATA,
|
||||
async_service_generate_data,
|
||||
schema=vol.Schema(
|
||||
{
|
||||
vol.Required(ATTR_TASK_NAME): cv.string,
|
||||
@@ -82,18 +82,18 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
return await hass.data[DATA_COMPONENT].async_unload_entry(entry)
|
||||
|
||||
|
||||
async def async_service_generate_text(call: ServiceCall) -> ServiceResponse:
|
||||
async def async_service_generate_data(call: ServiceCall) -> ServiceResponse:
|
||||
"""Run the run task service."""
|
||||
result = await async_generate_text(hass=call.hass, **call.data)
|
||||
return result.as_dict() # type: ignore[return-value]
|
||||
result = await async_generate_data(hass=call.hass, **call.data)
|
||||
return result.as_dict()
|
||||
|
||||
|
||||
class AITaskPreferences:
|
||||
"""AI Task preferences."""
|
||||
|
||||
KEYS = ("gen_text_entity_id",)
|
||||
KEYS = ("gen_data_entity_id",)
|
||||
|
||||
gen_text_entity_id: str | None = None
|
||||
gen_data_entity_id: str | None = None
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the preferences."""
|
||||
@@ -113,11 +113,11 @@ class AITaskPreferences:
|
||||
def async_set_preferences(
|
||||
self,
|
||||
*,
|
||||
gen_text_entity_id: str | None | UndefinedType = UNDEFINED,
|
||||
gen_data_entity_id: str | None | UndefinedType = UNDEFINED,
|
||||
) -> None:
|
||||
"""Set the preferences."""
|
||||
changed = False
|
||||
for key, value in (("gen_text_entity_id", gen_text_entity_id),):
|
||||
for key, value in (("gen_data_entity_id", gen_data_entity_id),):
|
||||
if value is not UNDEFINED:
|
||||
if getattr(self, key) != value:
|
||||
setattr(self, key, value)
|
||||
|
@@ -17,7 +17,7 @@ DOMAIN = "ai_task"
|
||||
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
|
||||
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
|
||||
|
||||
SERVICE_GENERATE_TEXT = "generate_text"
|
||||
SERVICE_GENERATE_DATA = "generate_data"
|
||||
|
||||
ATTR_INSTRUCTIONS: Final = "instructions"
|
||||
ATTR_TASK_NAME: Final = "task_name"
|
||||
@@ -30,5 +30,5 @@ DEFAULT_SYSTEM_PROMPT = (
|
||||
class AITaskEntityFeature(IntFlag):
|
||||
"""Supported features of the AI task entity."""
|
||||
|
||||
GENERATE_TEXT = 1
|
||||
"""Generate text based on instructions."""
|
||||
GENERATE_DATA = 1
|
||||
"""Generate data based on instructions."""
|
||||
|
@@ -18,7 +18,7 @@ from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .const import DEFAULT_SYSTEM_PROMPT, DOMAIN, AITaskEntityFeature
|
||||
from .task import GenTextTask, GenTextTaskResult
|
||||
from .task import GenDataTask, GenDataTaskResult
|
||||
|
||||
|
||||
class AITaskEntity(RestoreEntity):
|
||||
@@ -56,7 +56,7 @@ class AITaskEntity(RestoreEntity):
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_get_ai_task_chat_log(
|
||||
self,
|
||||
task: GenTextTask,
|
||||
task: GenDataTask,
|
||||
) -> AsyncGenerator[ChatLog]:
|
||||
"""Context manager used to manage the ChatLog used during an AI Task."""
|
||||
# pylint: disable-next=contextmanager-generator-missing-cleanup
|
||||
@@ -84,20 +84,20 @@ class AITaskEntity(RestoreEntity):
|
||||
yield chat_log
|
||||
|
||||
@final
|
||||
async def internal_async_generate_text(
|
||||
async def internal_async_generate_data(
|
||||
self,
|
||||
task: GenTextTask,
|
||||
) -> GenTextTaskResult:
|
||||
"""Run a gen text task."""
|
||||
task: GenDataTask,
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a gen data task."""
|
||||
self.__last_activity = dt_util.utcnow().isoformat()
|
||||
self.async_write_ha_state()
|
||||
async with self._async_get_ai_task_chat_log(task) as chat_log:
|
||||
return await self._async_generate_text(task, chat_log)
|
||||
return await self._async_generate_data(task, chat_log)
|
||||
|
||||
async def _async_generate_text(
|
||||
async def _async_generate_data(
|
||||
self,
|
||||
task: GenTextTask,
|
||||
task: GenDataTask,
|
||||
chat_log: ChatLog,
|
||||
) -> GenTextTaskResult:
|
||||
"""Handle a gen text task."""
|
||||
) -> GenDataTaskResult:
|
||||
"""Handle a gen data task."""
|
||||
raise NotImplementedError
|
||||
|
@@ -36,7 +36,7 @@ def websocket_get_preferences(
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "ai_task/preferences/set",
|
||||
vol.Optional("gen_text_entity_id"): vol.Any(str, None),
|
||||
vol.Optional("gen_data_entity_id"): vol.Any(str, None),
|
||||
}
|
||||
)
|
||||
@websocket_api.require_admin
|
||||
|
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"services": {
|
||||
"generate_text": {
|
||||
"generate_data": {
|
||||
"service": "mdi:file-star-four-points-outline"
|
||||
}
|
||||
}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
generate_text:
|
||||
generate_data:
|
||||
fields:
|
||||
task_name:
|
||||
example: "home summary"
|
||||
@@ -16,4 +16,4 @@ generate_text:
|
||||
entity:
|
||||
domain: ai_task
|
||||
supported_features:
|
||||
- ai_task.AITaskEntityFeature.GENERATE_TEXT
|
||||
- ai_task.AITaskEntityFeature.GENERATE_DATA
|
||||
|
@@ -1,8 +1,8 @@
|
||||
{
|
||||
"services": {
|
||||
"generate_text": {
|
||||
"name": "Generate text",
|
||||
"description": "Use AI to run a task that generates text.",
|
||||
"generate_data": {
|
||||
"name": "Generate data",
|
||||
"description": "Uses AI to run a task that generates data.",
|
||||
"fields": {
|
||||
"task_name": {
|
||||
"name": "Task name",
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@@ -10,16 +11,16 @@ from homeassistant.exceptions import HomeAssistantError
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||
|
||||
|
||||
async def async_generate_text(
|
||||
async def async_generate_data(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
task_name: str,
|
||||
entity_id: str | None = None,
|
||||
instructions: str,
|
||||
) -> GenTextTaskResult:
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
entity_id = hass.data[DATA_PREFERENCES].gen_text_entity_id
|
||||
entity_id = hass.data[DATA_PREFERENCES].gen_data_entity_id
|
||||
|
||||
if entity_id is None:
|
||||
raise HomeAssistantError("No entity_id provided and no preferred entity set")
|
||||
@@ -28,13 +29,13 @@ async def async_generate_text(
|
||||
if entity is None:
|
||||
raise HomeAssistantError(f"AI Task entity {entity_id} not found")
|
||||
|
||||
if AITaskEntityFeature.GENERATE_TEXT not in entity.supported_features:
|
||||
if AITaskEntityFeature.GENERATE_DATA not in entity.supported_features:
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support generating text"
|
||||
f"AI Task entity {entity_id} does not support generating data"
|
||||
)
|
||||
|
||||
return await entity.internal_async_generate_text(
|
||||
GenTextTask(
|
||||
return await entity.internal_async_generate_data(
|
||||
GenDataTask(
|
||||
name=task_name,
|
||||
instructions=instructions,
|
||||
)
|
||||
@@ -42,8 +43,8 @@ async def async_generate_text(
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GenTextTask:
|
||||
"""Gen text task to be processed."""
|
||||
class GenDataTask:
|
||||
"""Gen data task to be processed."""
|
||||
|
||||
name: str
|
||||
"""Name of the task."""
|
||||
@@ -53,22 +54,22 @@ class GenTextTask:
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return task as a string."""
|
||||
return f"<GenTextTask {self.name}: {id(self)}>"
|
||||
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GenTextTaskResult:
|
||||
"""Result of gen text task."""
|
||||
class GenDataTaskResult:
|
||||
"""Result of gen data task."""
|
||||
|
||||
conversation_id: str
|
||||
"""Unique identifier for the conversation."""
|
||||
|
||||
text: str
|
||||
"""Generated text."""
|
||||
data: Any
|
||||
"""Data generated by the task."""
|
||||
|
||||
def as_dict(self) -> dict[str, str]:
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Return result as a dict."""
|
||||
return {
|
||||
"conversation_id": self.conversation_id,
|
||||
"text": self.text,
|
||||
"data": self.data,
|
||||
}
|
||||
|
Reference in New Issue
Block a user