Anthropic repair deprecated models (#162162)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Denis Shulyaka
2026-02-04 16:40:31 +03:00
committed by GitHub
parent 54d64b7da2
commit e2469bcd0f
6 changed files with 672 additions and 36 deletions

View File

@@ -14,10 +14,18 @@ from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
entity_registry as er,
issue_registry as ir,
)
from homeassistant.helpers.typing import ConfigType
from .const import DEFAULT_CONVERSATION_NAME, DOMAIN, LOGGER
from .const import (
CONF_CHAT_MODEL,
DATA_REPAIR_DEFER_RELOAD,
DEFAULT_CONVERSATION_NAME,
DEPRECATED_MODELS,
DOMAIN,
LOGGER,
)
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
@@ -27,6 +35,7 @@ type AnthropicConfigEntry = ConfigEntry[anthropic.AsyncClient]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Anthropic."""
hass.data.setdefault(DOMAIN, {}).setdefault(DATA_REPAIR_DEFER_RELOAD, set())
await async_migrate_integration(hass)
return True
@@ -50,6 +59,22 @@ async def async_setup_entry(hass: HomeAssistant, entry: AnthropicConfigEntry) ->
entry.async_on_unload(entry.add_update_listener(async_update_options))
for subentry in entry.subentries.values():
if (model := subentry.data.get(CONF_CHAT_MODEL)) and model.startswith(
tuple(DEPRECATED_MODELS)
):
ir.async_create_issue(
hass,
DOMAIN,
"model_deprecated",
is_fixable=True,
is_persistent=False,
learn_more_url="https://platform.claude.com/docs/en/about-claude/model-deprecations",
severity=ir.IssueSeverity.WARNING,
translation_key="model_deprecated",
)
break
return True
@@ -62,6 +87,11 @@ async def async_update_options(
hass: HomeAssistant, entry: AnthropicConfigEntry
) -> None:
"""Update options."""
defer_reload_entries: set[str] = hass.data.setdefault(DOMAIN, {}).setdefault(
DATA_REPAIR_DEFER_RELOAD, set()
)
if entry.entry_id in defer_reload_entries:
return
await hass.config_entries.async_reload(entry.entry_id)

View File

@@ -92,6 +92,40 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
await client.models.list(timeout=10.0)
async def get_model_list(client: anthropic.AsyncAnthropic) -> list[SelectOptionDict]:
"""Get list of available models."""
try:
models = (await client.models.list()).data
except anthropic.AnthropicError:
models = []
_LOGGER.debug("Available models: %s", models)
model_options: list[SelectOptionDict] = []
short_form = re.compile(r"[^\d]-\d$")
for model_info in models:
# Resolve alias from versioned model name:
model_alias = (
model_info.id[:-9]
if model_info.id
not in (
"claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229",
)
else model_info.id
)
if short_form.search(model_alias):
model_alias += "-0"
if model_alias.endswith(("haiku", "opus", "sonnet")):
model_alias += "-latest"
model_options.append(
SelectOptionDict(
label=model_info.display_name,
value=model_alias,
)
)
return model_options
class AnthropicConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Anthropic."""
@@ -401,42 +435,13 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
async def _get_model_list(self) -> list[SelectOptionDict]:
"""Get list of available models."""
try:
client = await self.hass.async_add_executor_job(
partial(
anthropic.AsyncAnthropic,
api_key=self._get_entry().data[CONF_API_KEY],
)
client = await self.hass.async_add_executor_job(
partial(
anthropic.AsyncAnthropic,
api_key=self._get_entry().data[CONF_API_KEY],
)
models = (await client.models.list()).data
except anthropic.AnthropicError:
models = []
_LOGGER.debug("Available models: %s", models)
model_options: list[SelectOptionDict] = []
short_form = re.compile(r"[^\d]-\d$")
for model_info in models:
# Resolve alias from versioned model name:
model_alias = (
model_info.id[:-9]
if model_info.id
not in (
"claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229",
)
else model_info.id
)
if short_form.search(model_alias):
model_alias += "-0"
if model_alias.endswith(("haiku", "opus", "sonnet")):
model_alias += "-latest"
model_options.append(
SelectOptionDict(
label=model_info.display_name,
value=model_alias,
)
)
return model_options
)
return await get_model_list(client)
async def _get_location_data(self) -> dict[str, str]:
"""Get approximate location data of the user."""

View File

@@ -22,6 +22,8 @@ CONF_WEB_SEARCH_REGION = "region"
CONF_WEB_SEARCH_COUNTRY = "country"
CONF_WEB_SEARCH_TIMEZONE = "timezone"
DATA_REPAIR_DEFER_RELOAD = "repair_defer_reload"
DEFAULT = {
CONF_CHAT_MODEL: "claude-haiku-4-5",
CONF_MAX_TOKENS: 3000,
@@ -46,3 +48,10 @@ WEB_SEARCH_UNSUPPORTED_MODELS = [
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
]
DEPRECATED_MODELS = [
"claude-3-5-haiku",
"claude-3-7-sonnet",
"claude-3-5-sonnet",
"claude-3-opus",
]

View File

@@ -0,0 +1,275 @@
"""Issue repair flow for Anthropic."""
from __future__ import annotations
from collections.abc import Iterator
from typing import cast
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.components.repairs import RepairsFlow
from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig
from .config_flow import get_model_list
from .const import (
CONF_CHAT_MODEL,
DATA_REPAIR_DEFER_RELOAD,
DEFAULT,
DEPRECATED_MODELS,
DOMAIN,
)
class ModelDeprecatedRepairFlow(RepairsFlow):
"""Handler for an issue fixing flow."""
_subentry_iter: Iterator[tuple[str, str]] | None
_current_entry_id: str | None
_current_subentry_id: str | None
_reload_pending: set[str]
_pending_updates: dict[str, dict[str, str]]
def __init__(self) -> None:
"""Initialize the flow."""
super().__init__()
self._subentry_iter = None
self._current_entry_id = None
self._current_subentry_id = None
self._reload_pending = set()
self._pending_updates = {}
async def async_step_init(
self, user_input: dict[str, str] | None = None
) -> data_entry_flow.FlowResult:
"""Handle the first step of a fix flow."""
previous_entry_id: str | None = None
if user_input is not None:
previous_entry_id = self._async_update_current_subentry(user_input)
self._clear_current_target()
target = await self._async_next_target()
next_entry_id = target[0].entry_id if target else None
if previous_entry_id and previous_entry_id != next_entry_id:
await self._async_apply_pending_updates(previous_entry_id)
if target is None:
await self._async_apply_all_pending_updates()
return self.async_create_entry(data={})
entry, subentry, model = target
client = entry.runtime_data
model_list = [
model_option
for model_option in await get_model_list(client)
if not model_option["value"].startswith(tuple(DEPRECATED_MODELS))
]
if "opus" in model:
suggested_model = "claude-opus-4-5"
elif "haiku" in model:
suggested_model = "claude-haiku-4-5"
elif "sonnet" in model:
suggested_model = "claude-sonnet-4-5"
else:
suggested_model = cast(str, DEFAULT[CONF_CHAT_MODEL])
schema = vol.Schema(
{
vol.Required(
CONF_CHAT_MODEL,
default=suggested_model,
): SelectSelector(
SelectSelectorConfig(options=model_list, custom_value=True)
),
}
)
return self.async_show_form(
step_id="init",
data_schema=schema,
description_placeholders={
"entry_name": entry.title,
"model": model,
"subentry_name": subentry.title,
"subentry_type": self._format_subentry_type(subentry.subentry_type),
},
)
def _iter_deprecated_subentries(self) -> Iterator[tuple[str, str]]:
"""Yield entry/subentry pairs that use deprecated models."""
for entry in self.hass.config_entries.async_entries(DOMAIN):
if entry.state is not ConfigEntryState.LOADED:
continue
for subentry in entry.subentries.values():
model = subentry.data.get(CONF_CHAT_MODEL)
if model and model.startswith(tuple(DEPRECATED_MODELS)):
yield entry.entry_id, subentry.subentry_id
async def _async_next_target(
self,
) -> tuple[ConfigEntry, ConfigSubentry, str] | None:
"""Return the next deprecated subentry target."""
if self._subentry_iter is None:
self._subentry_iter = self._iter_deprecated_subentries()
while True:
try:
entry_id, subentry_id = next(self._subentry_iter)
except StopIteration:
return None
entry = self.hass.config_entries.async_get_entry(entry_id)
if entry is None:
continue
subentry = entry.subentries.get(subentry_id)
if subentry is None:
continue
model = self._pending_model(entry_id, subentry_id)
if model is None:
model = subentry.data.get(CONF_CHAT_MODEL)
if not model or not model.startswith(tuple(DEPRECATED_MODELS)):
continue
self._current_entry_id = entry_id
self._current_subentry_id = subentry_id
return entry, subentry, model
def _async_update_current_subentry(self, user_input: dict[str, str]) -> str | None:
"""Update the currently selected subentry."""
if not self._current_entry_id or not self._current_subentry_id:
return None
entry = self.hass.config_entries.async_get_entry(self._current_entry_id)
if entry is None:
return None
subentry = entry.subentries.get(self._current_subentry_id)
if subentry is None:
return None
updated_data = {
**subentry.data,
CONF_CHAT_MODEL: user_input[CONF_CHAT_MODEL],
}
if updated_data == subentry.data:
return entry.entry_id
self._queue_pending_update(
entry.entry_id,
subentry.subentry_id,
updated_data[CONF_CHAT_MODEL],
)
return entry.entry_id
def _clear_current_target(self) -> None:
"""Clear current target tracking."""
self._current_entry_id = None
self._current_subentry_id = None
def _format_subentry_type(self, subentry_type: str) -> str:
"""Return a user-friendly subentry type label."""
if subentry_type == "conversation":
return "Conversation agent"
if subentry_type in ("ai_task", "ai_task_data"):
return "AI task"
return subentry_type
def _queue_pending_update(
self, entry_id: str, subentry_id: str, model: str
) -> None:
"""Store a pending model update for a subentry."""
self._pending_updates.setdefault(entry_id, {})[subentry_id] = model
def _pending_model(self, entry_id: str, subentry_id: str) -> str | None:
"""Return a pending model update if one exists."""
return self._pending_updates.get(entry_id, {}).get(subentry_id)
def _mark_entry_for_reload(self, entry_id: str) -> None:
"""Prevent reload until repairs are complete for the entry."""
self._reload_pending.add(entry_id)
defer_reload_entries: set[str] = self.hass.data.setdefault(
DOMAIN, {}
).setdefault(DATA_REPAIR_DEFER_RELOAD, set())
defer_reload_entries.add(entry_id)
async def _async_reload_entry(self, entry_id: str) -> None:
"""Reload an entry once all repairs are completed."""
if entry_id not in self._reload_pending:
return
entry = self.hass.config_entries.async_get_entry(entry_id)
if entry is not None and entry.state is not ConfigEntryState.LOADED:
self._clear_defer_reload(entry_id)
self._reload_pending.discard(entry_id)
return
if entry is not None:
await self.hass.config_entries.async_reload(entry_id)
self._clear_defer_reload(entry_id)
self._reload_pending.discard(entry_id)
def _clear_defer_reload(self, entry_id: str) -> None:
"""Remove entry from the deferred reload set."""
defer_reload_entries: set[str] = self.hass.data.setdefault(
DOMAIN, {}
).setdefault(DATA_REPAIR_DEFER_RELOAD, set())
defer_reload_entries.discard(entry_id)
async def _async_apply_pending_updates(self, entry_id: str) -> None:
"""Apply pending subentry updates for a single entry."""
updates = self._pending_updates.pop(entry_id, None)
if not updates:
return
entry = self.hass.config_entries.async_get_entry(entry_id)
if entry is None or entry.state is not ConfigEntryState.LOADED:
return
changed = False
for subentry_id, model in updates.items():
subentry = entry.subentries.get(subentry_id)
if subentry is None:
continue
updated_data = {
**subentry.data,
CONF_CHAT_MODEL: model,
}
if updated_data == subentry.data:
continue
if not changed:
self._mark_entry_for_reload(entry_id)
changed = True
self.hass.config_entries.async_update_subentry(
entry,
subentry,
data=updated_data,
)
if not changed:
return
await self._async_reload_entry(entry_id)
async def _async_apply_all_pending_updates(self) -> None:
"""Apply all pending updates across entries."""
for entry_id in list(self._pending_updates):
await self._async_apply_pending_updates(entry_id)
async def async_create_fix_flow(
hass: HomeAssistant,
issue_id: str,
data: dict[str, str | int | float | None] | None,
) -> RepairsFlow:
"""Create flow."""
if issue_id == "model_deprecated":
return ModelDeprecatedRepairFlow()
raise HomeAssistantError("Unknown issue ID")

View File

@@ -109,5 +109,21 @@
}
}
}
},
"issues": {
"model_deprecated": {
"fix_flow": {
"step": {
"init": {
"data": {
"chat_model": "[%key:common::generic::model%]"
},
"description": "You are updating {subentry_name} ({subentry_type}) in {entry_name}. The current model {model} is deprecated. Select a supported model to continue.",
"title": "Update model"
}
}
},
"title": "Model deprecated"
}
}
}

View File

@@ -0,0 +1,301 @@
"""Tests for the Anthropic repairs flow."""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock, call, patch
from homeassistant.components.anthropic.const import CONF_CHAT_MODEL, DOMAIN
from homeassistant.config_entries import ConfigEntryState, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import issue_registry as ir
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry
from tests.components.repairs import (
async_process_repairs_platforms,
process_repair_fix_flow,
start_repair_fix_flow,
)
from tests.typing import ClientSessionGenerator
def _make_entry(
hass: HomeAssistant,
*,
title: str,
api_key: str,
subentries_data: list[dict[str, Any]],
) -> MockConfigEntry:
"""Create a config entry with subentries and runtime data."""
entry = MockConfigEntry(
domain=DOMAIN,
title=title,
data={"api_key": api_key},
version=2,
subentries_data=subentries_data,
)
entry.add_to_hass(hass)
object.__setattr__(entry, "state", ConfigEntryState.LOADED)
entry.runtime_data = MagicMock()
return entry
def _get_subentry(
entry: MockConfigEntry,
subentry_type: str,
) -> ConfigSubentry:
"""Return the first subentry of a type."""
return next(
subentry
for subentry in entry.subentries.values()
if subentry.subentry_type == subentry_type
)
async def _setup_repairs(hass: HomeAssistant) -> None:
hass.config.components.add(DOMAIN)
assert await async_setup_component(hass, "repairs", {})
await async_process_repairs_platforms(hass)
async def test_repair_flow_iterates_subentries(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
issue_registry: ir.IssueRegistry,
) -> None:
"""Test the repair flow iterates across deprecated subentries."""
entry_one: MockConfigEntry = _make_entry(
hass,
title="Entry One",
api_key="key-one",
subentries_data=[
{
"data": {CONF_CHAT_MODEL: "claude-3-5-haiku-20241022"},
"subentry_type": "conversation",
"title": "Conversation One",
"unique_id": None,
},
{
"data": {CONF_CHAT_MODEL: "claude-3-5-sonnet-20241022"},
"subentry_type": "ai_task_data",
"title": "AI task One",
"unique_id": None,
},
],
)
entry_two: MockConfigEntry = _make_entry(
hass,
title="Entry Two",
api_key="key-two",
subentries_data=[
{
"data": {CONF_CHAT_MODEL: "claude-3-opus-20240229"},
"subentry_type": "conversation",
"title": "Conversation Two",
"unique_id": None,
},
],
)
ir.async_create_issue(
hass,
DOMAIN,
"model_deprecated",
is_fixable=True,
is_persistent=False,
severity=ir.IssueSeverity.WARNING,
translation_key="model_deprecated",
)
await _setup_repairs(hass)
client = await hass_client()
model_options: list[dict[str, str]] = [
{"label": "Claude Haiku 4.5", "value": "claude-haiku-4-5"},
{"label": "Claude Sonnet 4.5", "value": "claude-sonnet-4-5"},
{"label": "Claude Opus 4.5", "value": "claude-opus-4-5"},
]
with patch(
"homeassistant.components.anthropic.repairs.get_model_list",
new_callable=AsyncMock,
return_value=model_options,
):
result = await start_repair_fix_flow(client, DOMAIN, "model_deprecated")
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "init"
placeholders = result["description_placeholders"]
assert placeholders["entry_name"] == entry_one.title
assert placeholders["subentry_name"] == "Conversation One"
assert placeholders["subentry_type"] == "Conversation agent"
flow_id = result["flow_id"]
result = await process_repair_fix_flow(
client,
flow_id,
json={CONF_CHAT_MODEL: "claude-haiku-4-5"},
)
assert result["type"] == FlowResultType.FORM
assert (
_get_subentry(entry_one, "conversation").data[CONF_CHAT_MODEL]
== "claude-3-5-haiku-20241022"
)
placeholders = result["description_placeholders"]
assert placeholders["entry_name"] == entry_one.title
assert placeholders["subentry_name"] == "AI task One"
assert placeholders["subentry_type"] == "AI task"
result = await process_repair_fix_flow(
client,
flow_id,
json={CONF_CHAT_MODEL: "claude-sonnet-4-5"},
)
assert result["type"] == FlowResultType.FORM
assert (
_get_subentry(entry_one, "ai_task_data").data[CONF_CHAT_MODEL]
== "claude-sonnet-4-5"
)
assert (
_get_subentry(entry_one, "conversation").data[CONF_CHAT_MODEL]
== "claude-haiku-4-5"
)
placeholders = result["description_placeholders"]
assert placeholders["entry_name"] == entry_two.title
assert placeholders["subentry_name"] == "Conversation Two"
assert placeholders["subentry_type"] == "Conversation agent"
result = await process_repair_fix_flow(
client,
flow_id,
json={CONF_CHAT_MODEL: "claude-opus-4-5"},
)
assert result["type"] == FlowResultType.CREATE_ENTRY
assert (
_get_subentry(entry_two, "conversation").data[CONF_CHAT_MODEL]
== "claude-opus-4-5"
)
assert issue_registry.async_get_issue(DOMAIN, "model_deprecated") is None
async def test_repair_flow_no_deprecated_models(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
issue_registry: ir.IssueRegistry,
) -> None:
"""Test the repair flow completes when everything was fixed."""
_make_entry(
hass,
title="Entry One",
api_key="key-one",
subentries_data=[
{
"data": {CONF_CHAT_MODEL: "claude-sonnet-4-5"},
"subentry_type": "conversation",
"title": "Conversation One",
"unique_id": None,
}
],
)
ir.async_create_issue(
hass,
DOMAIN,
"model_deprecated",
is_fixable=True,
is_persistent=False,
severity=ir.IssueSeverity.WARNING,
translation_key="model_deprecated",
)
await _setup_repairs(hass)
client = await hass_client()
result = await start_repair_fix_flow(client, DOMAIN, "model_deprecated")
assert result["type"] == FlowResultType.CREATE_ENTRY
assert issue_registry.async_get_issue(DOMAIN, "model_deprecated") is None
async def test_repair_flow_defers_reload_until_entry_done(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
issue_registry: ir.IssueRegistry,
) -> None:
"""Test reload is deferred until all subentries for an entry are fixed."""
entry = _make_entry(
hass,
title="Entry One",
api_key="key-one",
subentries_data=[
{
"data": {CONF_CHAT_MODEL: "claude-3-5-haiku-20241022"},
"subentry_type": "conversation",
"title": "Conversation One",
"unique_id": None,
},
{
"data": {CONF_CHAT_MODEL: "claude-3-5-sonnet-20241022"},
"subentry_type": "ai_task_data",
"title": "AI task One",
"unique_id": None,
},
],
)
ir.async_create_issue(
hass,
DOMAIN,
"model_deprecated",
is_fixable=True,
is_persistent=False,
severity=ir.IssueSeverity.WARNING,
translation_key="model_deprecated",
)
await _setup_repairs(hass)
client = await hass_client()
model_options: list[dict[str, str]] = [
{"label": "Claude Haiku 4.5", "value": "claude-haiku-4-5"},
{"label": "Claude Sonnet 4.5", "value": "claude-sonnet-4-5"},
]
with (
patch(
"homeassistant.components.anthropic.repairs.get_model_list",
new_callable=AsyncMock,
return_value=model_options,
),
patch.object(
hass.config_entries,
"async_reload",
new_callable=AsyncMock,
return_value=True,
) as reload_mock,
):
result = await start_repair_fix_flow(client, DOMAIN, "model_deprecated")
flow_id = result["flow_id"]
assert result["step_id"] == "init"
result = await process_repair_fix_flow(
client,
flow_id,
json={CONF_CHAT_MODEL: "claude-haiku-4-5"},
)
assert result["type"] == FlowResultType.FORM
assert reload_mock.await_count == 0
result = await process_repair_fix_flow(
client,
flow_id,
json={CONF_CHAT_MODEL: "claude-sonnet-4-5"},
)
assert result["type"] == FlowResultType.CREATE_ENTRY
assert reload_mock.await_count == 1
assert reload_mock.call_args_list == [call(entry.entry_id)]