mirror of
https://github.com/home-assistant/core.git
synced 2025-08-30 09:51:37 +02:00
AI Task to support LLM APIs
This commit is contained in:
@@ -77,6 +77,7 @@ class AITaskEntity(RestoreEntity):
|
||||
device_id=None,
|
||||
),
|
||||
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
|
||||
user_llm_hass_api=task.llm_api,
|
||||
)
|
||||
|
||||
chat_log.async_add_user_content(
|
||||
|
@@ -13,6 +13,7 @@ import voluptuous as vol
|
||||
from homeassistant.components import camera, conversation, media_source
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.chat_session import async_get_chat_session
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||
@@ -37,6 +38,7 @@ async def async_generate_data(
|
||||
instructions: str,
|
||||
structure: vol.Schema | None = None,
|
||||
attachments: list[dict] | None = None,
|
||||
llm_api: llm.API | None = None,
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
@@ -126,6 +128,7 @@ async def async_generate_data(
|
||||
instructions=instructions,
|
||||
structure=structure,
|
||||
attachments=resolved_attachments or None,
|
||||
llm_api=llm_api,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -146,6 +149,9 @@ class GenDataTask:
|
||||
attachments: list[conversation.Attachment] | None = None
|
||||
"""List of attachments to go along the instructions."""
|
||||
|
||||
llm_api: llm.API | None = None
|
||||
"""API to provide to the LLM."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return task as a string."""
|
||||
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||
|
@@ -507,14 +507,18 @@ class ChatLog:
|
||||
async def async_provide_llm_data(
|
||||
self,
|
||||
llm_context: llm.LLMContext,
|
||||
user_llm_hass_api: str | list[str] | None = None,
|
||||
user_llm_hass_api: str | list[str] | llm.API | None = None,
|
||||
user_llm_prompt: str | None = None,
|
||||
user_extra_system_prompt: str | None = None,
|
||||
) -> None:
|
||||
"""Set the LLM system prompt."""
|
||||
llm_api: llm.APIInstance | None = None
|
||||
|
||||
if user_llm_hass_api:
|
||||
if user_llm_hass_api is None:
|
||||
pass
|
||||
elif isinstance(user_llm_hass_api, llm.API):
|
||||
llm_api = await user_llm_hass_api.async_get_api_instance(llm_context)
|
||||
else:
|
||||
try:
|
||||
llm_api = await llm.async_get_api(
|
||||
self.hass,
|
||||
|
@@ -15,7 +15,7 @@ from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import chat_session
|
||||
from homeassistant.helpers import chat_session, llm
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||
@@ -73,10 +73,12 @@ async def test_generate_data_preferred_entity(
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
llm_api = llm.AssistAPI(hass)
|
||||
result = await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
llm_api=llm_api,
|
||||
)
|
||||
assert result.data == "Mock result"
|
||||
as_dict = result.as_dict()
|
||||
@@ -86,6 +88,12 @@ async def test_generate_data_preferred_entity(
|
||||
assert state is not None
|
||||
assert state.state != STATE_UNKNOWN
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass, result.conversation_id) as session,
|
||||
async_get_chat_log(hass, session) as chat_log,
|
||||
):
|
||||
assert chat_log.llm_api.api is llm_api
|
||||
|
||||
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
|
Reference in New Issue
Block a user