mirror of
https://github.com/home-assistant/core.git
synced 2025-08-31 02:11:32 +02:00
AI Task to support LLM APIs
This commit is contained in:
@@ -77,6 +77,7 @@ class AITaskEntity(RestoreEntity):
|
|||||||
device_id=None,
|
device_id=None,
|
||||||
),
|
),
|
||||||
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
|
user_llm_prompt=DEFAULT_SYSTEM_PROMPT,
|
||||||
|
user_llm_hass_api=task.llm_api,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_log.async_add_user_content(
|
chat_log.async_add_user_content(
|
||||||
|
@@ -13,6 +13,7 @@ import voluptuous as vol
|
|||||||
from homeassistant.components import camera, conversation, media_source
|
from homeassistant.components import camera, conversation, media_source
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers import llm
|
||||||
from homeassistant.helpers.chat_session import async_get_chat_session
|
from homeassistant.helpers.chat_session import async_get_chat_session
|
||||||
|
|
||||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||||
@@ -37,6 +38,7 @@ async def async_generate_data(
|
|||||||
instructions: str,
|
instructions: str,
|
||||||
structure: vol.Schema | None = None,
|
structure: vol.Schema | None = None,
|
||||||
attachments: list[dict] | None = None,
|
attachments: list[dict] | None = None,
|
||||||
|
llm_api: llm.API | None = None,
|
||||||
) -> GenDataTaskResult:
|
) -> GenDataTaskResult:
|
||||||
"""Run a task in the AI Task integration."""
|
"""Run a task in the AI Task integration."""
|
||||||
if entity_id is None:
|
if entity_id is None:
|
||||||
@@ -126,6 +128,7 @@ async def async_generate_data(
|
|||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
structure=structure,
|
structure=structure,
|
||||||
attachments=resolved_attachments or None,
|
attachments=resolved_attachments or None,
|
||||||
|
llm_api=llm_api,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -146,6 +149,9 @@ class GenDataTask:
|
|||||||
attachments: list[conversation.Attachment] | None = None
|
attachments: list[conversation.Attachment] | None = None
|
||||||
"""List of attachments to go along the instructions."""
|
"""List of attachments to go along the instructions."""
|
||||||
|
|
||||||
|
llm_api: llm.API | None = None
|
||||||
|
"""API to provide to the LLM."""
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Return task as a string."""
|
"""Return task as a string."""
|
||||||
return f"<GenDataTask {self.name}: {id(self)}>"
|
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||||
|
@@ -507,14 +507,18 @@ class ChatLog:
|
|||||||
async def async_provide_llm_data(
|
async def async_provide_llm_data(
|
||||||
self,
|
self,
|
||||||
llm_context: llm.LLMContext,
|
llm_context: llm.LLMContext,
|
||||||
user_llm_hass_api: str | list[str] | None = None,
|
user_llm_hass_api: str | list[str] | llm.API | None = None,
|
||||||
user_llm_prompt: str | None = None,
|
user_llm_prompt: str | None = None,
|
||||||
user_extra_system_prompt: str | None = None,
|
user_extra_system_prompt: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set the LLM system prompt."""
|
"""Set the LLM system prompt."""
|
||||||
llm_api: llm.APIInstance | None = None
|
llm_api: llm.APIInstance | None = None
|
||||||
|
|
||||||
if user_llm_hass_api:
|
if user_llm_hass_api is None:
|
||||||
|
pass
|
||||||
|
elif isinstance(user_llm_hass_api, llm.API):
|
||||||
|
llm_api = await user_llm_hass_api.async_get_api_instance(llm_context)
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
llm_api = await llm.async_get_api(
|
llm_api = await llm.async_get_api(
|
||||||
self.hass,
|
self.hass,
|
||||||
|
@@ -15,7 +15,7 @@ from homeassistant.components.conversation import async_get_chat_log
|
|||||||
from homeassistant.const import STATE_UNKNOWN
|
from homeassistant.const import STATE_UNKNOWN
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import chat_session
|
from homeassistant.helpers import chat_session, llm
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
from .conftest import TEST_ENTITY_ID, MockAITaskEntity
|
||||||
@@ -73,10 +73,12 @@ async def test_generate_data_preferred_entity(
|
|||||||
assert state is not None
|
assert state is not None
|
||||||
assert state.state == STATE_UNKNOWN
|
assert state.state == STATE_UNKNOWN
|
||||||
|
|
||||||
|
llm_api = llm.AssistAPI(hass)
|
||||||
result = await async_generate_data(
|
result = await async_generate_data(
|
||||||
hass,
|
hass,
|
||||||
task_name="Test Task",
|
task_name="Test Task",
|
||||||
instructions="Test prompt",
|
instructions="Test prompt",
|
||||||
|
llm_api=llm_api,
|
||||||
)
|
)
|
||||||
assert result.data == "Mock result"
|
assert result.data == "Mock result"
|
||||||
as_dict = result.as_dict()
|
as_dict = result.as_dict()
|
||||||
@@ -86,6 +88,12 @@ async def test_generate_data_preferred_entity(
|
|||||||
assert state is not None
|
assert state is not None
|
||||||
assert state.state != STATE_UNKNOWN
|
assert state.state != STATE_UNKNOWN
|
||||||
|
|
||||||
|
with (
|
||||||
|
chat_session.async_get_chat_session(hass, result.conversation_id) as session,
|
||||||
|
async_get_chat_log(hass, session) as chat_log,
|
||||||
|
):
|
||||||
|
assert chat_log.llm_api.api is llm_api
|
||||||
|
|
||||||
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
HomeAssistantError,
|
HomeAssistantError,
|
||||||
|
Reference in New Issue
Block a user