AI task generate_text -> generate_data (#147370)

This commit is contained in:
Paulus Schoutsen
2025-06-24 07:12:29 -04:00
committed by GitHub
parent 38c7eaf70a
commit 63ac14a19b
14 changed files with 104 additions and 100 deletions

View File

@@ -6,8 +6,8 @@ from homeassistant.components.ai_task import (
DOMAIN,
AITaskEntity,
AITaskEntityFeature,
GenTextTask,
GenTextTaskResult,
GenDataTask,
GenDataTaskResult,
)
from homeassistant.components.conversation import AssistantContent, ChatLog
from homeassistant.config_entries import ConfigEntry, ConfigFlow
@@ -33,24 +33,24 @@ class MockAITaskEntity(AITaskEntity):
"""Mock AI Task entity for testing."""
_attr_name = "Test Task Entity"
_attr_supported_features = AITaskEntityFeature.GENERATE_TEXT
_attr_supported_features = AITaskEntityFeature.GENERATE_DATA
def __init__(self) -> None:
"""Initialize the mock entity."""
super().__init__()
self.mock_generate_text_tasks = []
self.mock_generate_data_tasks = []
async def _async_generate_text(
self, task: GenTextTask, chat_log: ChatLog
) -> GenTextTaskResult:
"""Mock handling of generate text task."""
self.mock_generate_text_tasks.append(task)
async def _async_generate_data(
self, task: GenDataTask, chat_log: ChatLog
) -> GenDataTaskResult:
"""Mock handling of generate data task."""
self.mock_generate_data_tasks.append(task)
chat_log.async_add_assistant_content_without_tools(
AssistantContent(self.entity_id, "Mock result")
)
return GenTextTaskResult(
return GenDataTaskResult(
conversation_id=chat_log.conversation_id,
text="Mock result",
data="Mock result",
)

View File

@@ -1,5 +1,5 @@
# serializer version: 1
# name: test_run_text_task_updates_chat_log
# name: test_run_data_task_updates_chat_log
list([
dict({
'content': '''

View File

@@ -2,7 +2,7 @@
from freezegun import freeze_time
from homeassistant.components.ai_task import async_generate_text
from homeassistant.components.ai_task import async_generate_data
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
@@ -12,28 +12,28 @@ from tests.common import MockConfigEntry
@freeze_time("2025-06-08 16:28:13")
async def test_state_generate_text(
async def test_state_generate_data(
hass: HomeAssistant,
init_components: None,
mock_config_entry: MockConfigEntry,
mock_ai_task_entity: MockAITaskEntity,
) -> None:
"""Test the state of the AI Task entity is updated when generating text."""
"""Test the state of the AI Task entity is updated when generating data."""
entity = hass.states.get(TEST_ENTITY_ID)
assert entity is not None
assert entity.state == STATE_UNKNOWN
result = await async_generate_text(
result = await async_generate_data(
hass,
task_name="Test task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
assert result.text == "Mock result"
assert result.data == "Mock result"
entity = hass.states.get(TEST_ENTITY_ID)
assert entity.state == "2025-06-08T16:28:13+00:00"
assert mock_ai_task_entity.mock_generate_text_tasks
task = mock_ai_task_entity.mock_generate_text_tasks[0]
assert mock_ai_task_entity.mock_generate_data_tasks
task = mock_ai_task_entity.mock_generate_data_tasks[0]
assert task.instructions == "Test prompt"

View File

@@ -18,20 +18,20 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": None,
"gen_data_entity_id": None,
}
# Set preferences
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.summary_1",
"gen_data_entity_id": "ai_task.summary_1",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_1",
"gen_data_entity_id": "ai_task.summary_1",
}
# Get updated preferences
@@ -39,20 +39,20 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_1",
"gen_data_entity_id": "ai_task.summary_1",
}
# Update an existing preference
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}
# Get updated preferences
@@ -60,7 +60,7 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}
# No preferences set will preserve existing preferences
@@ -72,7 +72,7 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}
# Get updated preferences
@@ -80,5 +80,5 @@ async def test_ws_preferences(
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"gen_text_entity_id": "ai_task.summary_2",
"gen_data_entity_id": "ai_task.summary_2",
}

View File

@@ -49,7 +49,7 @@ async def test_preferences_storage_load(
("set_preferences", "msg_extra"),
[
(
{"gen_text_entity_id": TEST_ENTITY_ID},
{"gen_data_entity_id": TEST_ENTITY_ID},
{},
),
(
@@ -58,20 +58,20 @@ async def test_preferences_storage_load(
),
],
)
async def test_generate_text_service(
async def test_generate_data_service(
hass: HomeAssistant,
init_components: None,
freezer: FrozenDateTimeFactory,
set_preferences: dict[str, str | None],
msg_extra: dict[str, str],
) -> None:
"""Test the generate text service."""
"""Test the generate data service."""
preferences = hass.data[DATA_PREFERENCES]
preferences.async_set_preferences(**set_preferences)
result = await hass.services.async_call(
"ai_task",
"generate_text",
"generate_data",
{
"task_name": "Test Name",
"instructions": "Test prompt",
@@ -81,4 +81,4 @@ async def test_generate_text_service(
return_response=True,
)
assert result["text"] == "Mock result"
assert result["data"] == "Mock result"

View File

@@ -4,7 +4,7 @@ from freezegun import freeze_time
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_text
from homeassistant.components.ai_task import AITaskEntityFeature, async_generate_data
from homeassistant.components.conversation import async_get_chat_log
from homeassistant.const import STATE_UNKNOWN
from homeassistant.core import HomeAssistant
@@ -28,7 +28,7 @@ async def test_run_task_preferred_entity(
with pytest.raises(
HomeAssistantError, match="No entity_id provided and no preferred entity set"
):
await async_generate_text(
await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
@@ -37,7 +37,7 @@ async def test_run_task_preferred_entity(
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_entity_id": "ai_task.unknown",
"gen_data_entity_id": "ai_task.unknown",
}
)
msg = await client.receive_json()
@@ -46,7 +46,7 @@ async def test_run_task_preferred_entity(
with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown not found"
):
await async_generate_text(
await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
@@ -55,7 +55,7 @@ async def test_run_task_preferred_entity(
await client.send_json_auto_id(
{
"type": "ai_task/preferences/set",
"gen_text_entity_id": TEST_ENTITY_ID,
"gen_data_entity_id": TEST_ENTITY_ID,
}
)
msg = await client.receive_json()
@@ -65,12 +65,15 @@ async def test_run_task_preferred_entity(
assert state is not None
assert state.state == STATE_UNKNOWN
result = await async_generate_text(
result = await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
)
assert result.text == "Mock result"
assert result.data == "Mock result"
as_dict = result.as_dict()
assert as_dict["conversation_id"] == result.conversation_id
assert as_dict["data"] == "Mock result"
state = hass.states.get(TEST_ENTITY_ID)
assert state is not None
assert state.state != STATE_UNKNOWN
@@ -78,25 +81,25 @@ async def test_run_task_preferred_entity(
mock_ai_task_entity.supported_features = AITaskEntityFeature(0)
with pytest.raises(
HomeAssistantError,
match="AI Task entity ai_task.test_task_entity does not support generating text",
match="AI Task entity ai_task.test_task_entity does not support generating data",
):
await async_generate_text(
await async_generate_data(
hass,
task_name="Test Task",
instructions="Test prompt",
)
async def test_run_text_task_unknown_entity(
async def test_run_data_task_unknown_entity(
hass: HomeAssistant,
init_components: None,
) -> None:
"""Test running a text task with an unknown entity."""
"""Test running a data task with an unknown entity."""
with pytest.raises(
HomeAssistantError, match="AI Task entity ai_task.unknown_entity not found"
):
await async_generate_text(
await async_generate_data(
hass,
task_name="Test Task",
entity_id="ai_task.unknown_entity",
@@ -105,19 +108,19 @@ async def test_run_text_task_unknown_entity(
@freeze_time("2025-06-14 22:59:00")
async def test_run_text_task_updates_chat_log(
async def test_run_data_task_updates_chat_log(
hass: HomeAssistant,
init_components: None,
snapshot: SnapshotAssertion,
) -> None:
"""Test that running a text task updates the chat log."""
result = await async_generate_text(
"""Test that running a data task updates the chat log."""
result = await async_generate_data(
hass,
task_name="Test Task",
entity_id=TEST_ENTITY_ID,
instructions="Test prompt",
)
assert result.text == "Mock result"
assert result.data == "Mock result"
with (
chat_session.async_get_chat_session(hass, result.conversation_id) as session,