diff --git a/homeassistant/components/ai_task/entity.py b/homeassistant/components/ai_task/entity.py index 4c5cd186943..e5674ad0a9c 100644 --- a/homeassistant/components/ai_task/entity.py +++ b/homeassistant/components/ai_task/entity.py @@ -77,6 +77,7 @@ class AITaskEntity(RestoreEntity): device_id=None, ), user_llm_prompt=DEFAULT_SYSTEM_PROMPT, + user_llm_hass_api=task.llm_api, ) chat_log.async_add_user_content( diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index 3cc43f8c07a..d36330ff987 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -13,6 +13,7 @@ import voluptuous as vol from homeassistant.components import camera, conversation, media_source from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import llm from homeassistant.helpers.chat_session import async_get_chat_session from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature @@ -37,6 +38,7 @@ async def async_generate_data( instructions: str, structure: vol.Schema | None = None, attachments: list[dict] | None = None, + llm_api: llm.API | None = None, ) -> GenDataTaskResult: """Run a task in the AI Task integration.""" if entity_id is None: @@ -126,6 +128,7 @@ async def async_generate_data( instructions=instructions, structure=structure, attachments=resolved_attachments or None, + llm_api=llm_api, ), ) @@ -146,6 +149,9 @@ class GenDataTask: attachments: list[conversation.Attachment] | None = None """List of attachments to go along the instructions.""" + llm_api: llm.API | None = None """API to provide to the LLM.""" + def __str__(self) -> str: """Return task as a string.""" return f"<GenDataTask {self.name}>" diff --git a/homeassistant/components/conversation/chat_log.py b/homeassistant/components/conversation/chat_log.py index 2f5e3b0cf82..56a0b46f52b 100644 --- a/homeassistant/components/conversation/chat_log.py +++ 
b/homeassistant/components/conversation/chat_log.py @@ -507,14 +507,18 @@ class ChatLog: async def async_provide_llm_data( self, llm_context: llm.LLMContext, - user_llm_hass_api: str | list[str] | None = None, + user_llm_hass_api: str | list[str] | llm.API | None = None, user_llm_prompt: str | None = None, user_extra_system_prompt: str | None = None, ) -> None: """Set the LLM system prompt.""" llm_api: llm.APIInstance | None = None - if user_llm_hass_api: + if user_llm_hass_api is None: + pass + elif isinstance(user_llm_hass_api, llm.API): + llm_api = await user_llm_hass_api.async_get_api_instance(llm_context) + else: try: llm_api = await llm.async_get_api( self.hass, diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py index 7eb75b62bb0..ab26e13f7e7 100644 --- a/tests/components/ai_task/test_task.py +++ b/tests/components/ai_task/test_task.py @@ -15,7 +15,7 @@ from homeassistant.components.conversation import async_get_chat_log from homeassistant.const import STATE_UNKNOWN from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError -from homeassistant.helpers import chat_session +from homeassistant.helpers import chat_session, llm from homeassistant.util import dt as dt_util from .conftest import TEST_ENTITY_ID, MockAITaskEntity @@ -73,10 +73,12 @@ async def test_generate_data_preferred_entity( assert state is not None assert state.state == STATE_UNKNOWN + llm_api = llm.AssistAPI(hass) result = await async_generate_data( hass, task_name="Test Task", instructions="Test prompt", + llm_api=llm_api, ) assert result.data == "Mock result" as_dict = result.as_dict() @@ -86,6 +88,12 @@ async def test_generate_data_preferred_entity( assert state is not None assert state.state != STATE_UNKNOWN + with ( + chat_session.async_get_chat_session(hass, result.conversation_id) as session, + async_get_chat_log(hass, session) as chat_log, + ): + assert chat_log.llm_api.api is llm_api + 
mock_ai_task_entity.supported_features = AITaskEntityFeature(0) with pytest.raises( HomeAssistantError,