mirror of
https://github.com/home-assistant/core.git
synced 2025-08-03 20:55:10 +02:00
AI task generate_text -> generate_data (#147370)
This commit is contained in:
@@ -6,8 +6,8 @@ from homeassistant.components.ai_task import (
|
||||
DOMAIN,
|
||||
AITaskEntity,
|
||||
AITaskEntityFeature,
|
||||
GenTextTask,
|
||||
GenTextTaskResult,
|
||||
GenDataTask,
|
||||
GenDataTaskResult,
|
||||
)
|
||||
from homeassistant.components.conversation import AssistantContent, ChatLog
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
@@ -33,24 +33,24 @@ class MockAITaskEntity(AITaskEntity):
|
||||
"""Mock AI Task entity for testing."""
|
||||
|
||||
_attr_name = "Test Task Entity"
|
||||
_attr_supported_features = AITaskEntityFeature.GENERATE_TEXT
|
||||
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock entity."""
|
||||
super().__init__()
|
||||
self.mock_generate_text_tasks = []
|
||||
self.mock_generate_data_tasks = []
|
||||
|
||||
async def _async_generate_text(
|
||||
self, task: GenTextTask, chat_log: ChatLog
|
||||
) -> GenTextTaskResult:
|
||||
"""Mock handling of generate text task."""
|
||||
self.mock_generate_text_tasks.append(task)
|
||||
async def _async_generate_data(
|
||||
self, task: GenDataTask, chat_log: ChatLog
|
||||
) -> GenDataTaskResult:
|
||||
"""Mock handling of generate data task."""
|
||||
self.mock_generate_data_tasks.append(task)
|
||||
chat_log.async_add_assistant_content_without_tools(
|
||||
AssistantContent(self.entity_id, "Mock result")
|
||||
)
|
||||
return GenTextTaskResult(
|
||||
return GenDataTaskResult(
|
||||
conversation_id=chat_log.conversation_id,
|
||||
text="Mock result",
|
||||
data="Mock result",
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
# serializer version: 1
|
||||
# name: test_run_text_task_updates_chat_log
|
||||
# name: test_run_data_task_updates_chat_log
|
||||
list([
|
||||
dict({
|
||||
'content': '''
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
from freezegun import freeze_time
|
||||
|
||||
from homeassistant.components.ai_task import async_generate_text
|
||||
from homeassistant.components.ai_task import async_generate_data
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
@@ -12,28 +12,28 @@ from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@freeze_time("2025-06-08 16:28:13")
|
||||
async def test_state_generate_text(
|
||||
async def test_state_generate_data(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_ai_task_entity: MockAITaskEntity,
|
||||
) -> None:
|
||||
"""Test the state of the AI Task entity is updated when generating text."""
|
||||
"""Test the state of the AI Task entity is updated when generating data."""
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity is not None
|
||||
assert entity.state == STATE_UNKNOWN
|
||||
|
||||
result = await async_generate_text(
|
||||
result = await async_generate_data(
|
||||
hass,
|
||||
task_name="Test task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.text == "Mock result"
|
||||
assert result.data == "Mock result"
|
||||
|
||||
entity = hass.states.get(TEST_ENTITY_ID)
|
||||
assert entity.state == "2025-06-08T16:28:13+00:00"
|
||||
|
||||
assert mock_ai_task_entity.mock_generate_text_tasks
|
||||
task = mock_ai_task_entity.mock_generate_text_tasks[0]
|
||||
assert mock_ai_task_entity.mock_generate_data_tasks
|
||||
task = mock_ai_task_entity.mock_generate_data_tasks[0]
|
||||
assert task.instructions == "Test prompt"
|
||||
|
@@ -18,20 +18,20 @@ async def test_ws_preferences(
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": None,
|
||||
"gen_data_entity_id": None,
|
||||
}
|
||||
|
||||
# Set preferences
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
"gen_data_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
"gen_data_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
@@ -39,20 +39,20 @@ async def test_ws_preferences(
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_1",
|
||||
"gen_data_entity_id": "ai_task.summary_1",
|
||||
}
|
||||
|
||||
# Update an existing preference
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
@@ -60,7 +60,7 @@ async def test_ws_preferences(
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# No preferences set will preserve existing preferences
|
||||
@@ -72,7 +72,7 @@ async def test_ws_preferences(
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
||||
# Get updated preferences
|
||||
@@ -80,5 +80,5 @@ async def test_ws_preferences(
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"gen_text_entity_id": "ai_task.summary_2",
|
||||
"gen_data_entity_id": "ai_task.summary_2",
|
||||
}
|
||||
|
@@ -49,7 +49,7 @@ async def test_preferences_storage_load(
|
||||
("set_preferences", "msg_extra"),
|
||||
[
|
||||
(
|
||||
{"gen_text_entity_id": TEST_ENTITY_ID},
|
||||
{"gen_data_entity_id": TEST_ENTITY_ID},
|
||||
{},
|
||||
),
|
||||
(
|
||||
@@ -58,20 +58,20 @@ async def test_preferences_storage_load(
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_generate_text_service(
|
||||
async def test_generate_data_service(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
freezer: FrozenDateTimeFactory,
|
||||
set_preferences: dict[str, str | None],
|
||||
msg_extra: dict[str, str],
|
||||
) -> None:
|
||||
"""Test the generate text service."""
|
||||
"""Test the generate data service."""
|
||||
preferences = hass.data[DATA_PREFERENCES]
|
||||
preferences.async_set_preferences(**set_preferences)
|
||||
|
||||
result = await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_text",
|
||||
"generate_data",
|
||||
{
|
||||
"task_name": "Test Name",
|
||||
"instructions": "Test prompt",
|
||||
@@ -81,4 +81,4 @@ async def test_generate_text_service(
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
assert result["text"] == "Mock result"
|
||||
assert result["data"] == "Mock result"
|
||||
|
@@ -4,7 +4,7 @@ from freezegun import freeze_time
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_text
|
||||
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
|
||||
from homeassistant.components.conversation import async_get_chat_log
|
||||
from homeassistant.const import STATE_UNKNOWN
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -28,7 +28,7 @@ async def test_run_task_preferred_entity(
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="No entity_id provided and no preferred entity set"
|
||||
):
|
||||
await async_generate_text(
|
||||
await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
@@ -37,7 +37,7 @@ async def test_run_task_preferred_entity(
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": "ai_task.unknown",
|
||||
"gen_data_entity_id": "ai_task.unknown",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
@@ -46,7 +46,7 @@ async def test_run_task_preferred_entity(
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
|
||||
):
|
||||
await async_generate_text(
|
||||
await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
@@ -55,7 +55,7 @@ async def test_run_task_preferred_entity(
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "ai_task/preferences/set",
|
||||
"gen_text_entity_id": TEST_ENTITY_ID,
|
||||
"gen_data_entity_id": TEST_ENTITY_ID,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
@@ -65,12 +65,15 @@ async def test_run_task_preferred_entity(
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNKNOWN
|
||||
|
||||
result = await async_generate_text(
|
||||
result = await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.text == "Mock result"
|
||||
assert result.data == "Mock result"
|
||||
as_dict = result.as_dict()
|
||||
assert as_dict["conversation_id"] == result.conversation_id
|
||||
assert as_dict["data"] == "Mock result"
|
||||
state = hass.states.get(TEST_ENTITY_ID)
|
||||
assert state is not None
|
||||
assert state.state != STATE_UNKNOWN
|
||||
@@ -78,25 +81,25 @@ async def test_run_task_preferred_entity(
|
||||
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
|
||||
with pytest.raises(
|
||||
HomeAssistantError,
|
||||
match="AI Task entity ai_task.test_task_entity does not support generating text",
|
||||
match="AI Task entity ai_task.test_task_entity does not support generating data",
|
||||
):
|
||||
await async_generate_text(
|
||||
await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
instructions="Test prompt",
|
||||
)
|
||||
|
||||
|
||||
async def test_run_text_task_unknown_entity(
|
||||
async def test_run_data_task_unknown_entity(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
) -> None:
|
||||
"""Test running a text task with an unknown entity."""
|
||||
"""Test running a data task with an unknown entity."""
|
||||
|
||||
with pytest.raises(
|
||||
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
|
||||
):
|
||||
await async_generate_text(
|
||||
await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.unknown_entity",
|
||||
@@ -105,19 +108,19 @@ async def test_run_text_task_unknown_entity(
|
||||
|
||||
|
||||
@freeze_time("2025-06-14 22:59:00")
|
||||
async def test_run_text_task_updates_chat_log(
|
||||
async def test_run_data_task_updates_chat_log(
|
||||
hass: HomeAssistant,
|
||||
init_components: None,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that running a text task updates the chat log."""
|
||||
result = await async_generate_text(
|
||||
"""Test that running a data task updates the chat log."""
|
||||
result = await async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id=TEST_ENTITY_ID,
|
||||
instructions="Test prompt",
|
||||
)
|
||||
assert result.text == "Mock result"
|
||||
assert result.data == "Mock result"
|
||||
|
||||
with (
|
||||
chat_session.async_get_chat_session(hass, result.conversation_id) as session,
|
||||
|
Reference in New Issue
Block a user