mirror of
https://github.com/home-assistant/core.git
synced 2026-02-07 07:44:50 +01:00
Anthropic repair deprecated models (#162162)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -14,10 +14,18 @@ from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
issue_registry as ir,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DEFAULT_CONVERSATION_NAME, DOMAIN, LOGGER
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
DATA_REPAIR_DEFER_RELOAD,
|
||||
DEFAULT_CONVERSATION_NAME,
|
||||
DEPRECATED_MODELS,
|
||||
DOMAIN,
|
||||
LOGGER,
|
||||
)
|
||||
|
||||
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
|
||||
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
|
||||
@@ -27,6 +35,7 @@ type AnthropicConfigEntry = ConfigEntry[anthropic.AsyncClient]
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up Anthropic."""
|
||||
hass.data.setdefault(DOMAIN, {}).setdefault(DATA_REPAIR_DEFER_RELOAD, set())
|
||||
await async_migrate_integration(hass)
|
||||
return True
|
||||
|
||||
@@ -50,6 +59,22 @@ async def async_setup_entry(hass: HomeAssistant, entry: AnthropicConfigEntry) ->
|
||||
|
||||
entry.async_on_unload(entry.add_update_listener(async_update_options))
|
||||
|
||||
for subentry in entry.subentries.values():
|
||||
if (model := subentry.data.get(CONF_CHAT_MODEL)) and model.startswith(
|
||||
tuple(DEPRECATED_MODELS)
|
||||
):
|
||||
ir.async_create_issue(
|
||||
hass,
|
||||
DOMAIN,
|
||||
"model_deprecated",
|
||||
is_fixable=True,
|
||||
is_persistent=False,
|
||||
learn_more_url="https://platform.claude.com/docs/en/about-claude/model-deprecations",
|
||||
severity=ir.IssueSeverity.WARNING,
|
||||
translation_key="model_deprecated",
|
||||
)
|
||||
break
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -62,6 +87,11 @@ async def async_update_options(
|
||||
hass: HomeAssistant, entry: AnthropicConfigEntry
|
||||
) -> None:
|
||||
"""Update options."""
|
||||
defer_reload_entries: set[str] = hass.data.setdefault(DOMAIN, {}).setdefault(
|
||||
DATA_REPAIR_DEFER_RELOAD, set()
|
||||
)
|
||||
if entry.entry_id in defer_reload_entries:
|
||||
return
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,40 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
|
||||
await client.models.list(timeout=10.0)
|
||||
|
||||
|
||||
async def get_model_list(client: anthropic.AsyncAnthropic) -> list[SelectOptionDict]:
|
||||
"""Get list of available models."""
|
||||
try:
|
||||
models = (await client.models.list()).data
|
||||
except anthropic.AnthropicError:
|
||||
models = []
|
||||
_LOGGER.debug("Available models: %s", models)
|
||||
model_options: list[SelectOptionDict] = []
|
||||
short_form = re.compile(r"[^\d]-\d$")
|
||||
for model_info in models:
|
||||
# Resolve alias from versioned model name:
|
||||
model_alias = (
|
||||
model_info.id[:-9]
|
||||
if model_info.id
|
||||
not in (
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-opus-20240229",
|
||||
)
|
||||
else model_info.id
|
||||
)
|
||||
if short_form.search(model_alias):
|
||||
model_alias += "-0"
|
||||
if model_alias.endswith(("haiku", "opus", "sonnet")):
|
||||
model_alias += "-latest"
|
||||
model_options.append(
|
||||
SelectOptionDict(
|
||||
label=model_info.display_name,
|
||||
value=model_alias,
|
||||
)
|
||||
)
|
||||
return model_options
|
||||
|
||||
|
||||
class AnthropicConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for Anthropic."""
|
||||
|
||||
@@ -401,42 +435,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
|
||||
async def _get_model_list(self) -> list[SelectOptionDict]:
|
||||
"""Get list of available models."""
|
||||
try:
|
||||
client = await self.hass.async_add_executor_job(
|
||||
partial(
|
||||
anthropic.AsyncAnthropic,
|
||||
api_key=self._get_entry().data[CONF_API_KEY],
|
||||
)
|
||||
client = await self.hass.async_add_executor_job(
|
||||
partial(
|
||||
anthropic.AsyncAnthropic,
|
||||
api_key=self._get_entry().data[CONF_API_KEY],
|
||||
)
|
||||
models = (await client.models.list()).data
|
||||
except anthropic.AnthropicError:
|
||||
models = []
|
||||
_LOGGER.debug("Available models: %s", models)
|
||||
model_options: list[SelectOptionDict] = []
|
||||
short_form = re.compile(r"[^\d]-\d$")
|
||||
for model_info in models:
|
||||
# Resolve alias from versioned model name:
|
||||
model_alias = (
|
||||
model_info.id[:-9]
|
||||
if model_info.id
|
||||
not in (
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-opus-20240229",
|
||||
)
|
||||
else model_info.id
|
||||
)
|
||||
if short_form.search(model_alias):
|
||||
model_alias += "-0"
|
||||
if model_alias.endswith(("haiku", "opus", "sonnet")):
|
||||
model_alias += "-latest"
|
||||
model_options.append(
|
||||
SelectOptionDict(
|
||||
label=model_info.display_name,
|
||||
value=model_alias,
|
||||
)
|
||||
)
|
||||
return model_options
|
||||
)
|
||||
return await get_model_list(client)
|
||||
|
||||
async def _get_location_data(self) -> dict[str, str]:
|
||||
"""Get approximate location data of the user."""
|
||||
|
||||
@@ -22,6 +22,8 @@ CONF_WEB_SEARCH_REGION = "region"
|
||||
CONF_WEB_SEARCH_COUNTRY = "country"
|
||||
CONF_WEB_SEARCH_TIMEZONE = "timezone"
|
||||
|
||||
DATA_REPAIR_DEFER_RELOAD = "repair_defer_reload"
|
||||
|
||||
DEFAULT = {
|
||||
CONF_CHAT_MODEL: "claude-haiku-4-5",
|
||||
CONF_MAX_TOKENS: 3000,
|
||||
@@ -46,3 +48,10 @@ WEB_SEARCH_UNSUPPORTED_MODELS = [
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
]
|
||||
|
||||
DEPRECATED_MODELS = [
|
||||
"claude-3-5-haiku",
|
||||
"claude-3-7-sonnet",
|
||||
"claude-3-5-sonnet",
|
||||
"claude-3-opus",
|
||||
]
|
||||
|
||||
275
homeassistant/components/anthropic/repairs.py
Normal file
275
homeassistant/components/anthropic/repairs.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Issue repair flow for Anthropic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant.components.repairs import RepairsFlow
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigSubentry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig
|
||||
|
||||
from .config_flow import get_model_list
|
||||
from .const import (
|
||||
CONF_CHAT_MODEL,
|
||||
DATA_REPAIR_DEFER_RELOAD,
|
||||
DEFAULT,
|
||||
DEPRECATED_MODELS,
|
||||
DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
class ModelDeprecatedRepairFlow(RepairsFlow):
|
||||
"""Handler for an issue fixing flow."""
|
||||
|
||||
_subentry_iter: Iterator[tuple[str, str]] | None
|
||||
_current_entry_id: str | None
|
||||
_current_subentry_id: str | None
|
||||
_reload_pending: set[str]
|
||||
_pending_updates: dict[str, dict[str, str]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the flow."""
|
||||
super().__init__()
|
||||
self._subentry_iter = None
|
||||
self._current_entry_id = None
|
||||
self._current_subentry_id = None
|
||||
self._reload_pending = set()
|
||||
self._pending_updates = {}
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> data_entry_flow.FlowResult:
|
||||
"""Handle the first step of a fix flow."""
|
||||
previous_entry_id: str | None = None
|
||||
if user_input is not None:
|
||||
previous_entry_id = self._async_update_current_subentry(user_input)
|
||||
self._clear_current_target()
|
||||
|
||||
target = await self._async_next_target()
|
||||
next_entry_id = target[0].entry_id if target else None
|
||||
if previous_entry_id and previous_entry_id != next_entry_id:
|
||||
await self._async_apply_pending_updates(previous_entry_id)
|
||||
if target is None:
|
||||
await self._async_apply_all_pending_updates()
|
||||
return self.async_create_entry(data={})
|
||||
|
||||
entry, subentry, model = target
|
||||
client = entry.runtime_data
|
||||
model_list = [
|
||||
model_option
|
||||
for model_option in await get_model_list(client)
|
||||
if not model_option["value"].startswith(tuple(DEPRECATED_MODELS))
|
||||
]
|
||||
|
||||
if "opus" in model:
|
||||
suggested_model = "claude-opus-4-5"
|
||||
elif "haiku" in model:
|
||||
suggested_model = "claude-haiku-4-5"
|
||||
elif "sonnet" in model:
|
||||
suggested_model = "claude-sonnet-4-5"
|
||||
else:
|
||||
suggested_model = cast(str, DEFAULT[CONF_CHAT_MODEL])
|
||||
|
||||
schema = vol.Schema(
|
||||
{
|
||||
vol.Required(
|
||||
CONF_CHAT_MODEL,
|
||||
default=suggested_model,
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(options=model_list, custom_value=True)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=schema,
|
||||
description_placeholders={
|
||||
"entry_name": entry.title,
|
||||
"model": model,
|
||||
"subentry_name": subentry.title,
|
||||
"subentry_type": self._format_subentry_type(subentry.subentry_type),
|
||||
},
|
||||
)
|
||||
|
||||
def _iter_deprecated_subentries(self) -> Iterator[tuple[str, str]]:
|
||||
"""Yield entry/subentry pairs that use deprecated models."""
|
||||
for entry in self.hass.config_entries.async_entries(DOMAIN):
|
||||
if entry.state is not ConfigEntryState.LOADED:
|
||||
continue
|
||||
for subentry in entry.subentries.values():
|
||||
model = subentry.data.get(CONF_CHAT_MODEL)
|
||||
if model and model.startswith(tuple(DEPRECATED_MODELS)):
|
||||
yield entry.entry_id, subentry.subentry_id
|
||||
|
||||
async def _async_next_target(
|
||||
self,
|
||||
) -> tuple[ConfigEntry, ConfigSubentry, str] | None:
|
||||
"""Return the next deprecated subentry target."""
|
||||
if self._subentry_iter is None:
|
||||
self._subentry_iter = self._iter_deprecated_subentries()
|
||||
|
||||
while True:
|
||||
try:
|
||||
entry_id, subentry_id = next(self._subentry_iter)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
entry = self.hass.config_entries.async_get_entry(entry_id)
|
||||
if entry is None:
|
||||
continue
|
||||
|
||||
subentry = entry.subentries.get(subentry_id)
|
||||
if subentry is None:
|
||||
continue
|
||||
|
||||
model = self._pending_model(entry_id, subentry_id)
|
||||
if model is None:
|
||||
model = subentry.data.get(CONF_CHAT_MODEL)
|
||||
if not model or not model.startswith(tuple(DEPRECATED_MODELS)):
|
||||
continue
|
||||
|
||||
self._current_entry_id = entry_id
|
||||
self._current_subentry_id = subentry_id
|
||||
return entry, subentry, model
|
||||
|
||||
def _async_update_current_subentry(self, user_input: dict[str, str]) -> str | None:
|
||||
"""Update the currently selected subentry."""
|
||||
if not self._current_entry_id or not self._current_subentry_id:
|
||||
return None
|
||||
|
||||
entry = self.hass.config_entries.async_get_entry(self._current_entry_id)
|
||||
if entry is None:
|
||||
return None
|
||||
|
||||
subentry = entry.subentries.get(self._current_subentry_id)
|
||||
if subentry is None:
|
||||
return None
|
||||
|
||||
updated_data = {
|
||||
**subentry.data,
|
||||
CONF_CHAT_MODEL: user_input[CONF_CHAT_MODEL],
|
||||
}
|
||||
if updated_data == subentry.data:
|
||||
return entry.entry_id
|
||||
self._queue_pending_update(
|
||||
entry.entry_id,
|
||||
subentry.subentry_id,
|
||||
updated_data[CONF_CHAT_MODEL],
|
||||
)
|
||||
return entry.entry_id
|
||||
|
||||
def _clear_current_target(self) -> None:
|
||||
"""Clear current target tracking."""
|
||||
self._current_entry_id = None
|
||||
self._current_subentry_id = None
|
||||
|
||||
def _format_subentry_type(self, subentry_type: str) -> str:
|
||||
"""Return a user-friendly subentry type label."""
|
||||
if subentry_type == "conversation":
|
||||
return "Conversation agent"
|
||||
if subentry_type in ("ai_task", "ai_task_data"):
|
||||
return "AI task"
|
||||
return subentry_type
|
||||
|
||||
def _queue_pending_update(
|
||||
self, entry_id: str, subentry_id: str, model: str
|
||||
) -> None:
|
||||
"""Store a pending model update for a subentry."""
|
||||
self._pending_updates.setdefault(entry_id, {})[subentry_id] = model
|
||||
|
||||
def _pending_model(self, entry_id: str, subentry_id: str) -> str | None:
|
||||
"""Return a pending model update if one exists."""
|
||||
return self._pending_updates.get(entry_id, {}).get(subentry_id)
|
||||
|
||||
def _mark_entry_for_reload(self, entry_id: str) -> None:
|
||||
"""Prevent reload until repairs are complete for the entry."""
|
||||
self._reload_pending.add(entry_id)
|
||||
defer_reload_entries: set[str] = self.hass.data.setdefault(
|
||||
DOMAIN, {}
|
||||
).setdefault(DATA_REPAIR_DEFER_RELOAD, set())
|
||||
defer_reload_entries.add(entry_id)
|
||||
|
||||
async def _async_reload_entry(self, entry_id: str) -> None:
|
||||
"""Reload an entry once all repairs are completed."""
|
||||
if entry_id not in self._reload_pending:
|
||||
return
|
||||
|
||||
entry = self.hass.config_entries.async_get_entry(entry_id)
|
||||
if entry is not None and entry.state is not ConfigEntryState.LOADED:
|
||||
self._clear_defer_reload(entry_id)
|
||||
self._reload_pending.discard(entry_id)
|
||||
return
|
||||
|
||||
if entry is not None:
|
||||
await self.hass.config_entries.async_reload(entry_id)
|
||||
|
||||
self._clear_defer_reload(entry_id)
|
||||
self._reload_pending.discard(entry_id)
|
||||
|
||||
def _clear_defer_reload(self, entry_id: str) -> None:
|
||||
"""Remove entry from the deferred reload set."""
|
||||
defer_reload_entries: set[str] = self.hass.data.setdefault(
|
||||
DOMAIN, {}
|
||||
).setdefault(DATA_REPAIR_DEFER_RELOAD, set())
|
||||
defer_reload_entries.discard(entry_id)
|
||||
|
||||
async def _async_apply_pending_updates(self, entry_id: str) -> None:
|
||||
"""Apply pending subentry updates for a single entry."""
|
||||
updates = self._pending_updates.pop(entry_id, None)
|
||||
if not updates:
|
||||
return
|
||||
|
||||
entry = self.hass.config_entries.async_get_entry(entry_id)
|
||||
if entry is None or entry.state is not ConfigEntryState.LOADED:
|
||||
return
|
||||
|
||||
changed = False
|
||||
for subentry_id, model in updates.items():
|
||||
subentry = entry.subentries.get(subentry_id)
|
||||
if subentry is None:
|
||||
continue
|
||||
|
||||
updated_data = {
|
||||
**subentry.data,
|
||||
CONF_CHAT_MODEL: model,
|
||||
}
|
||||
if updated_data == subentry.data:
|
||||
continue
|
||||
|
||||
if not changed:
|
||||
self._mark_entry_for_reload(entry_id)
|
||||
changed = True
|
||||
|
||||
self.hass.config_entries.async_update_subentry(
|
||||
entry,
|
||||
subentry,
|
||||
data=updated_data,
|
||||
)
|
||||
|
||||
if not changed:
|
||||
return
|
||||
|
||||
await self._async_reload_entry(entry_id)
|
||||
|
||||
async def _async_apply_all_pending_updates(self) -> None:
|
||||
"""Apply all pending updates across entries."""
|
||||
for entry_id in list(self._pending_updates):
|
||||
await self._async_apply_pending_updates(entry_id)
|
||||
|
||||
|
||||
async def async_create_fix_flow(
|
||||
hass: HomeAssistant,
|
||||
issue_id: str,
|
||||
data: dict[str, str | int | float | None] | None,
|
||||
) -> RepairsFlow:
|
||||
"""Create flow."""
|
||||
if issue_id == "model_deprecated":
|
||||
return ModelDeprecatedRepairFlow()
|
||||
raise HomeAssistantError("Unknown issue ID")
|
||||
@@ -109,5 +109,21 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"issues": {
|
||||
"model_deprecated": {
|
||||
"fix_flow": {
|
||||
"step": {
|
||||
"init": {
|
||||
"data": {
|
||||
"chat_model": "[%key:common::generic::model%]"
|
||||
},
|
||||
"description": "You are updating {subentry_name} ({subentry_type}) in {entry_name}. The current model {model} is deprecated. Select a supported model to continue.",
|
||||
"title": "Update model"
|
||||
}
|
||||
}
|
||||
},
|
||||
"title": "Model deprecated"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
301
tests/components/anthropic/test_repairs.py
Normal file
301
tests/components/anthropic/test_repairs.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""Tests for the Anthropic repairs flow."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
from homeassistant.components.anthropic.const import CONF_CHAT_MODEL, DOMAIN
|
||||
from homeassistant.config_entries import ConfigEntryState, ConfigSubentry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
from homeassistant.helpers import issue_registry as ir
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.components.repairs import (
|
||||
async_process_repairs_platforms,
|
||||
process_repair_fix_flow,
|
||||
start_repair_fix_flow,
|
||||
)
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
def _make_entry(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
title: str,
|
||||
api_key: str,
|
||||
subentries_data: list[dict[str, Any]],
|
||||
) -> MockConfigEntry:
|
||||
"""Create a config entry with subentries and runtime data."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
title=title,
|
||||
data={"api_key": api_key},
|
||||
version=2,
|
||||
subentries_data=subentries_data,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
object.__setattr__(entry, "state", ConfigEntryState.LOADED)
|
||||
entry.runtime_data = MagicMock()
|
||||
return entry
|
||||
|
||||
|
||||
def _get_subentry(
|
||||
entry: MockConfigEntry,
|
||||
subentry_type: str,
|
||||
) -> ConfigSubentry:
|
||||
"""Return the first subentry of a type."""
|
||||
return next(
|
||||
subentry
|
||||
for subentry in entry.subentries.values()
|
||||
if subentry.subentry_type == subentry_type
|
||||
)
|
||||
|
||||
|
||||
async def _setup_repairs(hass: HomeAssistant) -> None:
|
||||
hass.config.components.add(DOMAIN)
|
||||
assert await async_setup_component(hass, "repairs", {})
|
||||
await async_process_repairs_platforms(hass)
|
||||
|
||||
|
||||
async def test_repair_flow_iterates_subentries(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
issue_registry: ir.IssueRegistry,
|
||||
) -> None:
|
||||
"""Test the repair flow iterates across deprecated subentries."""
|
||||
entry_one: MockConfigEntry = _make_entry(
|
||||
hass,
|
||||
title="Entry One",
|
||||
api_key="key-one",
|
||||
subentries_data=[
|
||||
{
|
||||
"data": {CONF_CHAT_MODEL: "claude-3-5-haiku-20241022"},
|
||||
"subentry_type": "conversation",
|
||||
"title": "Conversation One",
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"data": {CONF_CHAT_MODEL: "claude-3-5-sonnet-20241022"},
|
||||
"subentry_type": "ai_task_data",
|
||||
"title": "AI task One",
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
entry_two: MockConfigEntry = _make_entry(
|
||||
hass,
|
||||
title="Entry Two",
|
||||
api_key="key-two",
|
||||
subentries_data=[
|
||||
{
|
||||
"data": {CONF_CHAT_MODEL: "claude-3-opus-20240229"},
|
||||
"subentry_type": "conversation",
|
||||
"title": "Conversation Two",
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
ir.async_create_issue(
|
||||
hass,
|
||||
DOMAIN,
|
||||
"model_deprecated",
|
||||
is_fixable=True,
|
||||
is_persistent=False,
|
||||
severity=ir.IssueSeverity.WARNING,
|
||||
translation_key="model_deprecated",
|
||||
)
|
||||
|
||||
await _setup_repairs(hass)
|
||||
client = await hass_client()
|
||||
|
||||
model_options: list[dict[str, str]] = [
|
||||
{"label": "Claude Haiku 4.5", "value": "claude-haiku-4-5"},
|
||||
{"label": "Claude Sonnet 4.5", "value": "claude-sonnet-4-5"},
|
||||
{"label": "Claude Opus 4.5", "value": "claude-opus-4-5"},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.anthropic.repairs.get_model_list",
|
||||
new_callable=AsyncMock,
|
||||
return_value=model_options,
|
||||
):
|
||||
result = await start_repair_fix_flow(client, DOMAIN, "model_deprecated")
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "init"
|
||||
placeholders = result["description_placeholders"]
|
||||
assert placeholders["entry_name"] == entry_one.title
|
||||
assert placeholders["subentry_name"] == "Conversation One"
|
||||
assert placeholders["subentry_type"] == "Conversation agent"
|
||||
|
||||
flow_id = result["flow_id"]
|
||||
|
||||
result = await process_repair_fix_flow(
|
||||
client,
|
||||
flow_id,
|
||||
json={CONF_CHAT_MODEL: "claude-haiku-4-5"},
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert (
|
||||
_get_subentry(entry_one, "conversation").data[CONF_CHAT_MODEL]
|
||||
== "claude-3-5-haiku-20241022"
|
||||
)
|
||||
|
||||
placeholders = result["description_placeholders"]
|
||||
assert placeholders["entry_name"] == entry_one.title
|
||||
assert placeholders["subentry_name"] == "AI task One"
|
||||
assert placeholders["subentry_type"] == "AI task"
|
||||
|
||||
result = await process_repair_fix_flow(
|
||||
client,
|
||||
flow_id,
|
||||
json={CONF_CHAT_MODEL: "claude-sonnet-4-5"},
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert (
|
||||
_get_subentry(entry_one, "ai_task_data").data[CONF_CHAT_MODEL]
|
||||
== "claude-sonnet-4-5"
|
||||
)
|
||||
assert (
|
||||
_get_subentry(entry_one, "conversation").data[CONF_CHAT_MODEL]
|
||||
== "claude-haiku-4-5"
|
||||
)
|
||||
|
||||
placeholders = result["description_placeholders"]
|
||||
assert placeholders["entry_name"] == entry_two.title
|
||||
assert placeholders["subentry_name"] == "Conversation Two"
|
||||
assert placeholders["subentry_type"] == "Conversation agent"
|
||||
|
||||
result = await process_repair_fix_flow(
|
||||
client,
|
||||
flow_id,
|
||||
json={CONF_CHAT_MODEL: "claude-opus-4-5"},
|
||||
)
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert (
|
||||
_get_subentry(entry_two, "conversation").data[CONF_CHAT_MODEL]
|
||||
== "claude-opus-4-5"
|
||||
)
|
||||
|
||||
assert issue_registry.async_get_issue(DOMAIN, "model_deprecated") is None
|
||||
|
||||
|
||||
async def test_repair_flow_no_deprecated_models(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
issue_registry: ir.IssueRegistry,
|
||||
) -> None:
|
||||
"""Test the repair flow completes when everything was fixed."""
|
||||
_make_entry(
|
||||
hass,
|
||||
title="Entry One",
|
||||
api_key="key-one",
|
||||
subentries_data=[
|
||||
{
|
||||
"data": {CONF_CHAT_MODEL: "claude-sonnet-4-5"},
|
||||
"subentry_type": "conversation",
|
||||
"title": "Conversation One",
|
||||
"unique_id": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
ir.async_create_issue(
|
||||
hass,
|
||||
DOMAIN,
|
||||
"model_deprecated",
|
||||
is_fixable=True,
|
||||
is_persistent=False,
|
||||
severity=ir.IssueSeverity.WARNING,
|
||||
translation_key="model_deprecated",
|
||||
)
|
||||
|
||||
await _setup_repairs(hass)
|
||||
client = await hass_client()
|
||||
|
||||
result = await start_repair_fix_flow(client, DOMAIN, "model_deprecated")
|
||||
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert issue_registry.async_get_issue(DOMAIN, "model_deprecated") is None
|
||||
|
||||
|
||||
async def test_repair_flow_defers_reload_until_entry_done(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
issue_registry: ir.IssueRegistry,
|
||||
) -> None:
|
||||
"""Test reload is deferred until all subentries for an entry are fixed."""
|
||||
entry = _make_entry(
|
||||
hass,
|
||||
title="Entry One",
|
||||
api_key="key-one",
|
||||
subentries_data=[
|
||||
{
|
||||
"data": {CONF_CHAT_MODEL: "claude-3-5-haiku-20241022"},
|
||||
"subentry_type": "conversation",
|
||||
"title": "Conversation One",
|
||||
"unique_id": None,
|
||||
},
|
||||
{
|
||||
"data": {CONF_CHAT_MODEL: "claude-3-5-sonnet-20241022"},
|
||||
"subentry_type": "ai_task_data",
|
||||
"title": "AI task One",
|
||||
"unique_id": None,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
ir.async_create_issue(
|
||||
hass,
|
||||
DOMAIN,
|
||||
"model_deprecated",
|
||||
is_fixable=True,
|
||||
is_persistent=False,
|
||||
severity=ir.IssueSeverity.WARNING,
|
||||
translation_key="model_deprecated",
|
||||
)
|
||||
|
||||
await _setup_repairs(hass)
|
||||
client = await hass_client()
|
||||
|
||||
model_options: list[dict[str, str]] = [
|
||||
{"label": "Claude Haiku 4.5", "value": "claude-haiku-4-5"},
|
||||
{"label": "Claude Sonnet 4.5", "value": "claude-sonnet-4-5"},
|
||||
]
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.anthropic.repairs.get_model_list",
|
||||
new_callable=AsyncMock,
|
||||
return_value=model_options,
|
||||
),
|
||||
patch.object(
|
||||
hass.config_entries,
|
||||
"async_reload",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
) as reload_mock,
|
||||
):
|
||||
result = await start_repair_fix_flow(client, DOMAIN, "model_deprecated")
|
||||
flow_id = result["flow_id"]
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
result = await process_repair_fix_flow(
|
||||
client,
|
||||
flow_id,
|
||||
json={CONF_CHAT_MODEL: "claude-haiku-4-5"},
|
||||
)
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert reload_mock.await_count == 0
|
||||
|
||||
result = await process_repair_fix_flow(
|
||||
client,
|
||||
flow_id,
|
||||
json={CONF_CHAT_MODEL: "claude-sonnet-4-5"},
|
||||
)
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert reload_mock.await_count == 1
|
||||
assert reload_mock.call_args_list == [call(entry.entry_id)]
|
||||
Reference in New Issue
Block a user