AI Task to support LLM APIs

This commit is contained in:
Paulus Schoutsen
2025-08-24 07:51:06 +02:00
parent 03ca164fb3
commit b519106bc4
4 changed files with 22 additions and 3 deletions

View File

@@ -77,6 +77,7 @@ class AITaskEntity(RestoreEntity):
device_id=None,
),
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
user_llm_hass_api=task.llm_api,
)
chat_log.async_add_user_content(

View File

@@ -13,6 +13,7 @@ import voluptuous as vol
from homeassistant.components import camera, conversation, media_source
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.chat_session import async_get_chat_session
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
@@ -37,6 +38,7 @@ async def async_generate_data(
instructions: str,
structure: vol.Schema | None = None,
attachments: list[dict] | None = None,
llm_api: llm.API | None = None,
) -> GenDataTaskResult:
"""Run a task in the AI Task integration."""
if entity_id is None:
@@ -126,6 +128,7 @@ async def async_generate_data(
instructions=instructions,
structure=structure,
attachments=resolved_attachments or None,
llm_api=llm_api,
),
)
@@ -146,6 +149,9 @@ class GenDataTask:
attachments: list[conversation.Attachment] | None = None
"""List of attachments to go along the instructions."""
llm_api: llm.API | None = None
"""API to provide to the LLM."""
def __str__(self) -> str:
"""Return task as a string."""
return f"<GenDataTask {self.name}: {id(self)}>"

View File

@@ -507,14 +507,18 @@ class ChatLog:
async def async_provide_llm_data(
self,
llm_context: llm.LLMContext,
user_llm_hass_api: str | list[str] | None = None,
user_llm_hass_api: str | list[str] | llm.API | None = None,
user_llm_prompt: str | None = None,
user_extra_system_prompt: str | None = None,
) -> None:
"""Set the LLM system prompt."""
llm_api: llm.APIInstance | None = None
if user_llm_hass_api:
if user_llm_hass_api is None:
pass
elif isinstance(user_llm_hass_api, llm.API):
llm_api = await user_llm_hass_api.async_get_api_instance(llm_context)
else:
try:
llm_api = await llm.async_get_api(
self.hass,

View File

@@ -15,7 +15,7 @@ from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session
from homeassistant.helpers import chat_session, llm
from homeassistant.util import dt as dt_util
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
@@ -73,10 +73,12 @@ async def test_generate_data_preferred_entity(
assert state is not None
assert state.state == STATE_UNKNOWN
llm_api = llm.AssistAPI(hass)
result = await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
llm_api=llm_api,
)
assert result.data == "Mock result"
as_dict = result.as_dict()
@@ -86,6 +88,12 @@ async def test_generate_data_preferred_entity(
assert state is not None
assert state.state != STATE_UNKNOWN
with (
chat_session.async_get_chat_session(hass, result.conversation_id) as session,
async_get_chat_log(hass, session) as chat_log,
):
assert chat_log.llm_api.api is llm_api
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
with pytest.raises(
HomeAssistantError,