diff --git a/homeassistant/components/conversation/agent_manager.py b/homeassistant/components/conversation/agent_manager.py index 9f31ccd6c62..aa8b7644900 100644 --- a/homeassistant/components/conversation/agent_manager.py +++ b/homeassistant/components/conversation/agent_manager.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses import logging from typing import Any @@ -20,6 +21,11 @@ from .models import ( ConversationInput, ConversationResult, ) +from .trace import ( + ConversationTraceEvent, + ConversationTraceEventType, + async_conversation_trace, +) _LOGGER = logging.getLogger(__name__) @@ -84,15 +90,23 @@ async def async_converse( language = hass.config.language _LOGGER.debug("Processing in %s: %s", language, text) - return await method( - ConversationInput( - text=text, - context=context, - conversation_id=conversation_id, - device_id=device_id, - language=language, - ) + conversation_input = ConversationInput( + text=text, + context=context, + conversation_id=conversation_id, + device_id=device_id, + language=language, ) + with async_conversation_trace() as trace: + trace.add_event( + ConversationTraceEvent( + ConversationTraceEventType.ASYNC_PROCESS, + dataclasses.asdict(conversation_input), + ) + ) + result = await method(conversation_input) + trace.set_result(**result.as_dict()) + return result class AgentManager: diff --git a/homeassistant/components/conversation/trace.py b/homeassistant/components/conversation/trace.py new file mode 100644 index 00000000000..0bd2fe8ed5b --- /dev/null +++ b/homeassistant/components/conversation/trace.py @@ -0,0 +1,118 @@ +"""Debug traces for conversation.""" + +from collections.abc import Generator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import asdict, dataclass, field +import enum +from typing import Any + +from homeassistant.util import dt as dt_util, ulid as ulid_util +from homeassistant.util.limited_size_dict import LimitedSizeDict + +STORED_TRACES = 3 + + +class ConversationTraceEventType(enum.StrEnum): + """Type of an event emitted during a conversation.""" + + ASYNC_PROCESS = "async_process" + """The conversation is started from user input.""" + + AGENT_DETAIL = "agent_detail" + """Event detail added by a conversation agent.""" + + LLM_TOOL_CALL = "llm_tool_call" + """An LLM Tool call""" + + +@dataclass(frozen=True) +class ConversationTraceEvent: + """Event emitted during a conversation.""" + + event_type: ConversationTraceEventType + data: dict[str, Any] | None = None + timestamp: str = field(default_factory=lambda: dt_util.utcnow().isoformat()) + + +class ConversationTrace: + """Stores debug data related to a conversation.""" + + def __init__(self) -> None: + """Initialize ConversationTrace.""" + self._trace_id = ulid_util.ulid_now() + self._events: list[ConversationTraceEvent] = [] + self._error: Exception | None = None + self._result: dict[str, Any] = {} + + @property + def trace_id(self) -> str: + """Identifier for this trace.""" + return self._trace_id + + def add_event(self, event: ConversationTraceEvent) -> None: + """Add an event to the trace.""" + self._events.append(event) + + def set_error(self, ex: Exception) -> None: + """Set error.""" + self._error = ex + + def set_result(self, **kwargs: Any) -> None: + """Set result.""" + self._result = {**kwargs} + + def as_dict(self) -> dict[str, Any]: + """Return dictionary version of this ConversationTrace.""" + result: dict[str, Any] = { + "id": self._trace_id, + "events": [asdict(event) for event in self._events], + } + if self._error is not None: + result["error"] = str(self._error) or self._error.__class__.__name__ + if self._result is not None: + result["result"] = self._result + return result + + +_current_trace: ContextVar[ConversationTrace | None] = ContextVar( + "current_trace", default=None +) +_recent_traces: LimitedSizeDict[str, ConversationTrace] = LimitedSizeDict( + size_limit=STORED_TRACES +) + + +def async_conversation_trace_append( + event_type: ConversationTraceEventType, event_data: dict[str, Any] +) -> None: + """Append a ConversationTraceEvent to the current active trace.""" + trace = _current_trace.get() + if not trace: + return + trace.add_event(ConversationTraceEvent(event_type, event_data)) + + +@contextmanager +def async_conversation_trace() -> Generator[ConversationTrace, None]: + """Create a new active ConversationTrace.""" + trace = ConversationTrace() + token = _current_trace.set(trace) + _recent_traces[trace.trace_id] = trace + try: + yield trace + except Exception as ex: + trace.set_error(ex) + raise + finally: + _current_trace.reset(token) + + +def async_get_traces() -> list[ConversationTrace]: + """Get the most recent traces.""" + return list(_recent_traces.values()) + + +def async_clear_traces() -> None: + """Clear all traces.""" + _recent_traces.clear() diff --git a/homeassistant/components/envisalink/manifest.json b/homeassistant/components/envisalink/manifest.json index 093ebf77eba..0cf9f165aa2 100644 --- a/homeassistant/components/envisalink/manifest.json +++ b/homeassistant/components/envisalink/manifest.json @@ -5,5 +5,5 @@ "documentation": "https://www.home-assistant.io/integrations/envisalink", "iot_class": "local_push", "loggers": ["pyenvisalink"], - "requirements": ["pyenvisalink==4.6"] + "requirements": ["pyenvisalink==4.7"] } diff --git a/homeassistant/components/firmata/__init__.py b/homeassistant/components/firmata/__init__.py index 283fd585d35..26fbe596aa8 100644 --- a/homeassistant/components/firmata/__init__.py +++ b/homeassistant/components/firmata/__init__.py @@ -1,6 +1,5 @@ """Support for Arduino-compatible Microcontrollers through Firmata.""" -import asyncio from copy import copy import logging @@ -212,16 +211,15 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: """Shutdown and close a Firmata board for a config entry.""" _LOGGER.debug("Closing Firmata board %s", config_entry.data[CONF_NAME]) - - unload_entries = [] - for conf, platform in CONF_PLATFORM_MAP.items(): - if conf in config_entry.data: - unload_entries.append( - hass.config_entries.async_forward_entry_unload(config_entry, platform) - ) - results = [] - if unload_entries: - results = await asyncio.gather(*unload_entries) + results: list[bool] = [] + if platforms := [ + platform + for conf, platform in CONF_PLATFORM_MAP.items() + if conf in config_entry.data + ]: + results.append( + await hass.config_entries.async_unload_platforms(config_entry, platforms) + ) results.append(await hass.data[DOMAIN].pop(config_entry.entry_id).async_reset()) return False not in results diff --git a/homeassistant/components/forecast_solar/__init__.py b/homeassistant/components/forecast_solar/__init__.py index f4cb1d0a631..00be13f1235 100644 --- a/homeassistant/components/forecast_solar/__init__.py +++ b/homeassistant/components/forecast_solar/__init__.py @@ -11,12 +11,13 @@ from .const import ( CONF_DAMPING_EVENING, CONF_DAMPING_MORNING, CONF_MODULES_POWER, - DOMAIN, ) from .coordinator import ForecastSolarDataUpdateCoordinator PLATFORMS = [Platform.SENSOR] +type ForecastSolarConfigEntry = ConfigEntry[ForecastSolarDataUpdateCoordinator] + async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Migrate old config entry.""" @@ -36,12 +37,14 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_setup_entry( + hass: HomeAssistant, entry: ForecastSolarConfigEntry +) -> bool: """Set up Forecast.Solar from a config entry.""" coordinator = ForecastSolarDataUpdateCoordinator(hass, entry) await coordinator.async_config_entry_first_refresh() - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator + entry.runtime_data = coordinator await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) @@ -52,11 +55,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - if unload_ok: - hass.data[DOMAIN].pop(entry.entry_id) - - return unload_ok + return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) async def async_update_options(hass: HomeAssistant, entry: ConfigEntry) -> None: diff --git a/homeassistant/components/forecast_solar/diagnostics.py b/homeassistant/components/forecast_solar/diagnostics.py index a9bcebdb3cd..cb33ac5dc5a 100644 --- a/homeassistant/components/forecast_solar/diagnostics.py +++ b/homeassistant/components/forecast_solar/diagnostics.py @@ -4,15 +4,11 @@ from __future__ import annotations from typing import Any -from forecast_solar import Estimate - from homeassistant.components.diagnostics import async_redact_data -from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_API_KEY, CONF_LATITUDE, CONF_LONGITUDE from homeassistant.core import HomeAssistant -from homeassistant.helpers.update_coordinator import DataUpdateCoordinator -from .const import DOMAIN +from . import ForecastSolarConfigEntry TO_REDACT = { CONF_API_KEY, @@ -22,10 +18,10 @@ TO_REDACT = { async def async_get_config_entry_diagnostics( - hass: HomeAssistant, entry: ConfigEntry + hass: HomeAssistant, entry: ForecastSolarConfigEntry ) -> dict[str, Any]: """Return diagnostics for a config entry.""" - coordinator: DataUpdateCoordinator[Estimate] = hass.data[DOMAIN][entry.entry_id] + coordinator = entry.runtime_data return { "entry": { diff --git a/homeassistant/components/forecast_solar/energy.py b/homeassistant/components/forecast_solar/energy.py index f4d03f26299..9031e5c1e1d 100644 --- a/homeassistant/components/forecast_solar/energy.py +++ b/homeassistant/components/forecast_solar/energy.py @@ -4,19 +4,21 @@ from __future__ import annotations from homeassistant.core import HomeAssistant -from .const import DOMAIN +from .coordinator import ForecastSolarDataUpdateCoordinator async def async_get_solar_forecast( hass: HomeAssistant, config_entry_id: str ) -> dict[str, dict[str, float | int]] | None: """Get solar forecast for a config entry ID.""" - if (coordinator := hass.data[DOMAIN].get(config_entry_id)) is None: + if ( + entry := hass.config_entries.async_get_entry(config_entry_id) + ) is None or not isinstance(entry.runtime_data, ForecastSolarDataUpdateCoordinator): return None return { "wh_hours": { timestamp.isoformat(): val - for timestamp, val in coordinator.data.wh_period.items() + for timestamp, val in entry.runtime_data.data.wh_period.items() } } diff --git a/homeassistant/components/forecast_solar/sensor.py b/homeassistant/components/forecast_solar/sensor.py index 8d35b38765a..c1fa971a89d 100644 --- a/homeassistant/components/forecast_solar/sensor.py +++ b/homeassistant/components/forecast_solar/sensor.py @@ -16,7 +16,6 @@ from homeassistant.components.sensor import ( SensorEntityDescription, SensorStateClass, ) -from homeassistant.config_entries import ConfigEntry from homeassistant.const import UnitOfEnergy, UnitOfPower from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo @@ -24,6 +23,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import StateType from homeassistant.helpers.update_coordinator import CoordinatorEntity +from . import ForecastSolarConfigEntry from .const import DOMAIN from .coordinator import ForecastSolarDataUpdateCoordinator @@ -133,10 +133,12 @@ SENSORS: tuple[ForecastSolarSensorEntityDescription, ...] = ( async def async_setup_entry( - hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback + hass: HomeAssistant, + entry: ForecastSolarConfigEntry, + async_add_entities: AddEntitiesCallback, ) -> None: """Defer sensor setup to the shared sensor module.""" - coordinator: ForecastSolarDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id] + coordinator = entry.runtime_data async_add_entities( ForecastSolarSensorEntity( diff --git a/homeassistant/components/fritz/button.py b/homeassistant/components/fritz/button.py index a0cbd54eaac..263521d23f4 100644 --- a/homeassistant/components/fritz/button.py +++ b/homeassistant/components/fritz/button.py @@ -5,7 +5,7 @@ from __future__ import annotations from collections.abc import Callable from dataclasses import dataclass import logging -from typing import Final +from typing import Any, Final from homeassistant.components.button import ( ButtonDeviceClass, @@ -30,7 +30,7 @@ _LOGGER = logging.getLogger(__name__) class FritzButtonDescription(ButtonEntityDescription): """Class to describe a Button entity.""" - press_action: Callable + press_action: Callable[[AvmWrapper], Any] BUTTONS: Final = [ diff --git a/homeassistant/components/fritz/const.py b/homeassistant/components/fritz/const.py index 3794a83dd7f..9a266507c25 100644 --- a/homeassistant/components/fritz/const.py +++ b/homeassistant/components/fritz/const.py @@ -57,9 +57,6 @@ ERROR_UPNP_NOT_CONFIGURED = "upnp_not_configured" ERROR_UNKNOWN = "unknown_error" FRITZ_SERVICES = "fritz_services" -SERVICE_REBOOT = "reboot" -SERVICE_RECONNECT = "reconnect" -SERVICE_CLEANUP = "cleanup" SERVICE_SET_GUEST_WIFI_PW = "set_guest_wifi_password" SWITCH_TYPE_DEFLECTION = "CallDeflection" diff --git a/homeassistant/components/fritz/coordinator.py b/homeassistant/components/fritz/coordinator.py index 7256085b93a..299679e642a 100644 --- a/homeassistant/components/fritz/coordinator.py +++ b/homeassistant/components/fritz/coordinator.py @@ -46,9 +46,6 @@ from .const import ( DEFAULT_USERNAME, DOMAIN, FRITZ_EXCEPTIONS, - SERVICE_CLEANUP, - SERVICE_REBOOT, - SERVICE_RECONNECT, SERVICE_SET_GUEST_WIFI_PW, MeshRoles, ) @@ -730,30 +727,6 @@ class FritzBoxTools(DataUpdateCoordinator[UpdateCoordinatorDataType]): ) try: - if service_call.service == SERVICE_REBOOT: - _LOGGER.warning( - 'Service "fritz.reboot" is deprecated, please use the corresponding' - " button entity instead" - ) - await self.async_trigger_reboot() - return - - if service_call.service == SERVICE_RECONNECT: - _LOGGER.warning( - 'Service "fritz.reconnect" is deprecated, please use the' - " corresponding button entity instead" - ) - await self.async_trigger_reconnect() - return - - if service_call.service == SERVICE_CLEANUP: - _LOGGER.warning( - 'Service "fritz.cleanup" is deprecated, please use the' - " corresponding button entity instead" - ) - await self.async_trigger_cleanup(config_entry) - return - if service_call.service == SERVICE_SET_GUEST_WIFI_PW: await self.async_trigger_set_guest_password( service_call.data.get("password"), diff --git a/homeassistant/components/fritz/services.py b/homeassistant/components/fritz/services.py index bd1f3136b01..bace7480ba5 100644 --- a/homeassistant/components/fritz/services.py +++ b/homeassistant/components/fritz/services.py @@ -11,14 +11,7 @@ from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.service import async_extract_config_entry_ids -from .const import ( - DOMAIN, - FRITZ_SERVICES, - SERVICE_CLEANUP, - SERVICE_REBOOT, - SERVICE_RECONNECT, - SERVICE_SET_GUEST_WIFI_PW, -) +from .const import DOMAIN, FRITZ_SERVICES, SERVICE_SET_GUEST_WIFI_PW from .coordinator import AvmWrapper _LOGGER = logging.getLogger(__name__) @@ -32,9 +25,6 @@ SERVICE_SCHEMA_SET_GUEST_WIFI_PW = vol.Schema( ) SERVICE_LIST: list[tuple[str, vol.Schema | None]] = [ - (SERVICE_CLEANUP, None), - (SERVICE_REBOOT, None), - (SERVICE_RECONNECT, None), (SERVICE_SET_GUEST_WIFI_PW, SERVICE_SCHEMA_SET_GUEST_WIFI_PW), ] diff --git a/homeassistant/components/fritz/services.yaml b/homeassistant/components/fritz/services.yaml index b9828280aa2..0ac7ca20c3d 100644 --- a/homeassistant/components/fritz/services.yaml +++ b/homeassistant/components/fritz/services.yaml @@ -1,31 +1,3 @@ -reconnect: - fields: - device_id: - required: true - selector: - device: - integration: fritz - entity: - device_class: connectivity -reboot: - fields: - device_id: - required: true - selector: - device: - integration: fritz - entity: - device_class: connectivity - -cleanup: - fields: - device_id: - required: true - selector: - device: - integration: fritz - entity: - device_class: connectivity set_guest_wifi_password: fields: device_id: diff --git a/homeassistant/components/fritz/strings.json b/homeassistant/components/fritz/strings.json index 30603ca9032..eb47f76f27e 100644 --- a/homeassistant/components/fritz/strings.json +++ b/homeassistant/components/fritz/strings.json @@ -144,42 +144,12 @@ } }, "services": { - "reconnect": { - "name": "[%key:component::fritz::entity::button::reconnect::name%]", - "description": "Reconnects your FRITZ!Box internet connection.", - "fields": { - "device_id": { - "name": "Fritz!Box Device", - "description": "Select the Fritz!Box to reconnect." - } - } - }, - "reboot": { - "name": "Reboot", - "description": "Reboots your FRITZ!Box.", - "fields": { - "device_id": { - "name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]", - "description": "Select the Fritz!Box to reboot." - } - } - }, - "cleanup": { - "name": "Remove stale device tracker entities", - "description": "Remove FRITZ!Box stale device_tracker entities.", - "fields": { - "device_id": { - "name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]", - "description": "Select the Fritz!Box to check." - } - } - }, "set_guest_wifi_password": { "name": "Set guest Wi-Fi password", "description": "Sets a new password for the guest Wi-Fi. The password must be between 8 and 63 characters long. If no additional parameter is set, the password will be auto-generated with a length of 12 characters.", "fields": { "device_id": { - "name": "[%key:component::fritz::services::reconnect::fields::device_id::name%]", + "name": "Fritz!Box Device", "description": "Select the Fritz!Box to configure." }, "password": { diff --git a/homeassistant/components/fronius/diagnostics.py b/homeassistant/components/fronius/diagnostics.py new file mode 100644 index 00000000000..17737ba31f8 --- /dev/null +++ b/homeassistant/components/fronius/diagnostics.py @@ -0,0 +1,46 @@ +"""Diagnostics support for Fronius.""" + +from typing import Any + +from homeassistant.components.diagnostics import async_redact_data +from homeassistant.core import HomeAssistant + +from . import FroniusConfigEntry + +TO_REDACT = {"unique_id", "unique_identifier", "serial"} + + +async def async_get_config_entry_diagnostics( + hass: HomeAssistant, config_entry: FroniusConfigEntry +) -> dict[str, Any]: + """Return diagnostics for a config entry.""" + diag: dict[str, Any] = {} + solar_net = config_entry.runtime_data + fronius = solar_net.fronius + + diag["config_entry"] = config_entry.as_dict() + diag["inverter_info"] = await fronius.inverter_info() + + diag["coordinators"] = {"inverters": {}} + for inv in solar_net.inverter_coordinators: + diag["coordinators"]["inverters"] |= inv.data + + diag["coordinators"]["logger"] = ( + solar_net.logger_coordinator.data if solar_net.logger_coordinator else None + ) + diag["coordinators"]["meter"] = ( + solar_net.meter_coordinator.data if solar_net.meter_coordinator else None + ) + diag["coordinators"]["ohmpilot"] = ( + solar_net.ohmpilot_coordinator.data if solar_net.ohmpilot_coordinator else None + ) + diag["coordinators"]["power_flow"] = ( + solar_net.power_flow_coordinator.data + if solar_net.power_flow_coordinator + else None + ) + diag["coordinators"]["storage"] = ( + solar_net.storage_coordinator.data if solar_net.storage_coordinator else None + ) + + return async_redact_data(diag, TO_REDACT) diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py index 50b626f553c..b559888cc5f 100644 --- a/homeassistant/components/google_generative_ai_conversation/config_flow.py +++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py @@ -181,8 +181,7 @@ async def google_generative_ai_config_option_schema( schema = { vol.Optional( CONF_PROMPT, - description={"suggested_value": options.get(CONF_PROMPT)}, - default=DEFAULT_PROMPT, + description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)}, ): TemplateSelector(), vol.Optional( CONF_LLM_HASS_API, diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py index 549883d4fb9..a83ffed2d88 100644 --- a/homeassistant/components/google_generative_ai_conversation/const.py +++ b/homeassistant/components/google_generative_ai_conversation/const.py @@ -22,4 +22,4 @@ CONF_HARASSMENT_BLOCK_THRESHOLD = "harassment_block_threshold" CONF_HATE_BLOCK_THRESHOLD = "hate_block_threshold" CONF_SEXUAL_BLOCK_THRESHOLD = "sexual_block_threshold" CONF_DANGEROUS_BLOCK_THRESHOLD = "dangerous_block_threshold" -RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_LOW_AND_ABOVE" +RECOMMENDED_HARM_BLOCK_THRESHOLD = "BLOCK_MEDIUM_AND_ABOVE" diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py index ad50c544ac7..f84bd81f80c 100644 --- a/homeassistant/components/google_generative_ai_conversation/conversation.py +++ b/homeassistant/components/google_generative_ai_conversation/conversation.py @@ -5,13 +5,14 @@ from __future__ import annotations from typing import Any, Literal import google.ai.generativelanguage as glm -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import GoogleAPICallError import google.generativeai as genai import google.generativeai.types as genai_types import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant @@ -205,15 +206,6 @@ class GoogleGenerativeAIConversationEntity( messages = [{}, {}] try: - prompt = template.Template( - self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass - ).async_render( - { - "ha_name": self.hass.config.location_name, - }, - parse_result=False, - ) - if llm_api: empty_tool_input = llm.ToolInput( tool_name="", @@ -226,9 +218,24 @@ class GoogleGenerativeAIConversationEntity( device_id=user_input.device_id, ) - prompt = ( - await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt + api_prompt = await llm_api.async_get_api_prompt(empty_tool_input) + + else: + api_prompt = llm.PROMPT_NO_API_CONFIGURED + + prompt = "\n".join( + ( + template.Template( + self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass + ).async_render( + { + "ha_name": self.hass.config.location_name, + }, + parse_result=False, + ), + api_prompt, ) + ) except TemplateError as err: LOGGER.error("Error rendering prompt: %s", err) @@ -244,6 +251,9 @@ class GoogleGenerativeAIConversationEntity( messages[1] = {"role": "model", "parts": "Ok"} LOGGER.debug("Input: '%s' with history: %s", user_input.text, messages) + trace.async_conversation_trace_append( + trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} + ) chat = model.start_chat(history=messages) chat_request = user_input.text @@ -252,15 +262,25 @@ class GoogleGenerativeAIConversationEntity( try: chat_response = await chat.send_message_async(chat_request) except ( - ClientError, + GoogleAPICallError, ValueError, genai_types.BlockedPromptException, genai_types.StopCandidateException, ) as err: - LOGGER.error("Error sending message: %s", err) + LOGGER.error("Error sending message: %s %s", type(err), err) + + if isinstance( + err, genai_types.StopCandidateException + ) and "finish_reason: SAFETY\n" in str(err): + error = "The message got blocked by your safety settings" + else: + error = ( + f"Sorry, I had a problem talking to Google Generative AI: {err}" + ) + intent_response.async_set_error( intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem talking to Google Generative AI: {err}", + error, ) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id diff --git a/homeassistant/components/integration/manifest.json b/homeassistant/components/integration/manifest.json index 9e5c597bd1a..029d4740c6f 100644 --- a/homeassistant/components/integration/manifest.json +++ b/homeassistant/components/integration/manifest.json @@ -1,6 +1,6 @@ { "domain": "integration", - "name": "Integration - Riemann sum integral", + "name": "Integral", "after_dependencies": ["counter"], "codeowners": ["@dgomes"], "config_flow": true, diff --git a/homeassistant/components/integration/strings.json b/homeassistant/components/integration/strings.json index 0f5231399b7..ed34b0842d5 100644 --- a/homeassistant/components/integration/strings.json +++ b/homeassistant/components/integration/strings.json @@ -1,5 +1,5 @@ { - "title": "Integration - Riemann sum integral sensor", + "title": "Integral sensor", "config": { "step": { "user": { diff --git a/homeassistant/components/meraki/device_tracker.py b/homeassistant/components/meraki/device_tracker.py index 9f0f4cd4545..a6eefe7345f 100644 --- a/homeassistant/components/meraki/device_tracker.py +++ b/homeassistant/components/meraki/device_tracker.py @@ -21,7 +21,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType CONF_VALIDATOR = "validator" CONF_SECRET = "secret" URL = "/api/meraki" -VERSION = "2.0" +ACCEPTED_VERSIONS = ["2.0", "2.1"] _LOGGER = logging.getLogger(__name__) @@ -74,7 +74,7 @@ class MerakiView(HomeAssistantView): if data["secret"] != self.secret: _LOGGER.error("Invalid Secret received from Meraki") return self.json_message("Invalid secret", HTTPStatus.UNPROCESSABLE_ENTITY) - if data["version"] != VERSION: + if data["version"] not in ACCEPTED_VERSIONS: _LOGGER.error("Invalid API version: %s", data["version"]) return self.json_message("Invalid version", HTTPStatus.UNPROCESSABLE_ENTITY) _LOGGER.debug("Valid Secret") diff --git a/homeassistant/components/minecraft_server/manifest.json b/homeassistant/components/minecraft_server/manifest.json index a00936852f0..8e098f98a15 100644 --- a/homeassistant/components/minecraft_server/manifest.json +++ b/homeassistant/components/minecraft_server/manifest.json @@ -6,6 +6,6 @@ "documentation": "https://www.home-assistant.io/integrations/minecraft_server", "iot_class": "local_polling", "loggers": ["dnspython", "mcstatus"], - "quality_scale": "gold", + "quality_scale": "platinum", "requirements": ["mcstatus==11.1.1"] } diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 1e946421bcf..39e2660ca03 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -39,6 +39,7 @@ from .client import ( # noqa: F401 MQTT, async_publish, async_subscribe, + async_subscribe_internal, publish, subscribe, ) @@ -311,7 +312,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: def collect_msg(msg: ReceiveMessage) -> None: messages.append((msg.topic, str(msg.payload).replace("\n", ""))) - unsub = await async_subscribe(hass, call.data["topic"], collect_msg) + unsub = async_subscribe_internal(hass, call.data["topic"], collect_msg) def write_dump() -> None: with open(hass.config.path("mqtt_dump.txt"), "w", encoding="utf8") as fp: @@ -459,7 +460,7 @@ async def websocket_subscribe( # Perform UTF-8 decoding directly in callback routine qos: int = msg.get("qos", DEFAULT_QOS) - connection.subscriptions[msg["id"]] = await async_subscribe( + connection.subscriptions[msg["id"]] = async_subscribe_internal( hass, msg["topic"], forward_messages, encoding=None, qos=qos ) @@ -522,24 +523,13 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: mqtt_client = mqtt_data.client # Unload publish and dump services. - hass.services.async_remove( - DOMAIN, - SERVICE_PUBLISH, - ) - hass.services.async_remove( - DOMAIN, - SERVICE_DUMP, - ) + hass.services.async_remove(DOMAIN, SERVICE_PUBLISH) + hass.services.async_remove(DOMAIN, SERVICE_DUMP) # Stop the discovery await discovery.async_stop(hass) # Unload the platforms - await asyncio.gather( - *( - hass.config_entries.async_forward_entry_unload(entry, component) - for component in mqtt_data.platforms_loaded - ) - ) + await hass.config_entries.async_unload_platforms(entry, mqtt_data.platforms_loaded) mqtt_data.platforms_loaded = set() await asyncio.sleep(0) # Unsubscribe reload dispatchers diff --git a/homeassistant/components/mqtt/alarm_control_panel.py b/homeassistant/components/mqtt/alarm_control_panel.py index e341d54e349..fe6650cbd0f 100644 --- a/homeassistant/components/mqtt/alarm_control_panel.py +++ b/homeassistant/components/mqtt/alarm_control_panel.py @@ -226,7 +226,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_alarm_disarm(self, code: str | None = None) -> None: """Send disarm command. diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index ce772855e78..61e5074378d 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -254,7 +254,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @callback def _value_is_expired(self, *_: Any) -> None: diff --git a/homeassistant/components/mqtt/camera.py b/homeassistant/components/mqtt/camera.py index 23457c8d4fc..2c6346f5794 100644 --- a/homeassistant/components/mqtt/camera.py +++ b/homeassistant/components/mqtt/camera.py @@ -3,6 +3,7 @@ from __future__ import annotations from base64 import b64decode +from functools import partial import logging from typing import TYPE_CHECKING @@ -20,7 +21,6 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import subscription from .config import MQTT_BASE_SCHEMA from .const import CONF_QOS, CONF_TOPIC -from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ReceiveMessage from .schemas import MQTT_ENTITY_COMMON_SCHEMA @@ -97,27 +97,31 @@ class MqttCamera(MqttEntity, Camera): """Return the config schema.""" return DISCOVERY_SCHEMA + @callback + def _image_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + if CONF_IMAGE_ENCODING in self._config: + self._last_image = b64decode(msg.payload) + else: + if TYPE_CHECKING: + assert isinstance(msg.payload, bytes) + self._last_image = msg.payload + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - @callback - @log_messages(self.hass, self.entity_id) - def message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - if CONF_IMAGE_ENCODING in self._config: - self._last_image = b64decode(msg.payload) - else: - if TYPE_CHECKING: - assert isinstance(msg.payload, bytes) - self._last_image = msg.payload - self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, { "state_topic": { "topic": self._config[CONF_TOPIC], - "msg_callback": message_received, + "msg_callback": partial( + self._message_callback, + self._image_received, + None, + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": None, } @@ -126,7 +130,7 @@ class MqttCamera(MqttEntity, Camera): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_camera_image( self, width: int | None = None, height: int | None = None diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 857b073a746..16db9a45b58 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -77,7 +77,6 @@ from .const import ( ) from .models import ( DATA_MQTT, - AsyncMessageCallbackType, MessageCallbackType, MqttData, PublishMessage, @@ -184,7 +183,7 @@ async def async_publish( async def async_subscribe( hass: HomeAssistant, topic: str, - msg_callback: AsyncMessageCallbackType | MessageCallbackType, + msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], qos: int = DEFAULT_QOS, encoding: str | None = DEFAULT_ENCODING, ) -> CALLBACK_TYPE: @@ -192,13 +191,25 @@ async def async_subscribe( Call the return value to unsubscribe. """ - if not mqtt_config_entry_enabled(hass): - raise HomeAssistantError( - f"Cannot subscribe to topic '{topic}', MQTT is not enabled", - translation_key="mqtt_not_setup_cannot_subscribe", - translation_domain=DOMAIN, - translation_placeholders={"topic": topic}, - ) + return async_subscribe_internal(hass, topic, msg_callback, qos, encoding) + + +@callback +def async_subscribe_internal( + hass: HomeAssistant, + topic: str, + msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], + qos: int = DEFAULT_QOS, + encoding: str | None = DEFAULT_ENCODING, +) -> CALLBACK_TYPE: + """Subscribe to an MQTT topic. + + This function is internal to the MQTT integration + and may change at any time. It should not be considered + a stable API. + + Call the return value to unsubscribe. + """ try: mqtt_data = hass.data[DATA_MQTT] except KeyError as exc: @@ -209,12 +220,15 @@ async def async_subscribe( translation_domain=DOMAIN, translation_placeholders={"topic": topic}, ) from exc - return await mqtt_data.client.async_subscribe( - topic, - msg_callback, - qos, - encoding, - ) + client = mqtt_data.client + if not client.connected and not mqtt_config_entry_enabled(hass): + raise HomeAssistantError( + f"Cannot subscribe to topic '{topic}', MQTT is not enabled", + translation_key="mqtt_not_setup_cannot_subscribe", + translation_domain=DOMAIN, + translation_placeholders={"topic": topic}, + ) + return client.async_subscribe(topic, msg_callback, qos, encoding) @bind_hass @@ -429,10 +443,10 @@ class MQTT: self.config_entry = config_entry self.conf = conf - self._simple_subscriptions: defaultdict[str, list[Subscription]] = defaultdict( - list + self._simple_subscriptions: defaultdict[str, set[Subscription]] = defaultdict( + set ) - self._wildcard_subscriptions: list[Subscription] = [] + self._wildcard_subscriptions: set[Subscription] = set() # _retained_topics prevents a Subscription from receiving a # retained message more than once per topic. This prevents flooding # already active subscribers when new subscribers subscribe to a topic @@ -452,7 +466,7 @@ class MQTT: self._should_reconnect: bool = True self._available_future: asyncio.Future[bool] | None = None - self._max_qos: dict[str, int] = {} # topic, max qos + self._max_qos: defaultdict[str, int] = defaultdict(int) # topic, max qos self._pending_subscriptions: dict[str, int] = {} # topic, qos self._unsubscribe_debouncer = EnsureJobAfterCooldown( UNSUBSCRIBE_COOLDOWN, self._async_perform_unsubscribes @@ -789,9 +803,9 @@ class MQTT: The caller is responsible clearing the cache of _matching_subscriptions. """ if subscription.is_simple_match: - self._simple_subscriptions[subscription.topic].append(subscription) + self._simple_subscriptions[subscription.topic].add(subscription) else: - self._wildcard_subscriptions.append(subscription) + self._wildcard_subscriptions.add(subscription) @callback def _async_untrack_subscription(self, subscription: Subscription) -> None: @@ -820,8 +834,8 @@ class MQTT: """Queue requested subscriptions.""" for subscription in subscriptions: topic, qos = subscription - max_qos = max(qos, self._max_qos.setdefault(topic, qos)) - self._max_qos[topic] = max_qos + if (max_qos := self._max_qos[topic]) < qos: + self._max_qos[topic] = (max_qos := qos) self._pending_subscriptions[topic] = max_qos # Cancel any pending unsubscribe since we are subscribing now if topic in self._pending_unsubscribes: @@ -832,26 +846,29 @@ class MQTT: def _exception_message( self, - msg_callback: AsyncMessageCallbackType | MessageCallbackType, + msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], msg: ReceiveMessage, ) -> str: """Return a string with the exception message.""" + # if msg_callback is a partial we return the name of the first argument + if isinstance(msg_callback, partial): + call_back_name = getattr(msg_callback.args[0], "__name__") # type: ignore[unreachable] + else: + call_back_name = getattr(msg_callback, "__name__") return ( - f"Exception in {msg_callback.__name__} when handling msg on " + f"Exception in {call_back_name} when handling msg on " f"'{msg.topic}': '{msg.payload}'" # type: ignore[str-bytes-safe] ) - async def async_subscribe( + @callback + def async_subscribe( self, topic: str, - msg_callback: AsyncMessageCallbackType | MessageCallbackType, + msg_callback: Callable[[ReceiveMessage], Coroutine[Any, Any, None] | None], qos: int, encoding: str | None = None, ) -> Callable[[], None]: - """Set up a subscription to a topic with the provided qos. - - This method is a coroutine. - """ + """Set up a subscription to a topic with the provided qos.""" if not isinstance(topic, str): raise HomeAssistantError("Topic needs to be a string!") @@ -877,18 +894,18 @@ class MQTT: if self.connected: self._async_queue_subscriptions(((topic, qos),)) - @callback - def async_remove() -> None: - """Remove subscription.""" - self._async_untrack_subscription(subscription) - self._matching_subscriptions.cache_clear() - if subscription in self._retained_topics: - del self._retained_topics[subscription] - # Only unsubscribe if currently connected - if self.connected: - self._async_unsubscribe(topic) + return partial(self._async_remove, subscription) - return async_remove + @callback + def _async_remove(self, subscription: Subscription) -> None: + """Remove subscription.""" + self._async_untrack_subscription(subscription) + self._matching_subscriptions.cache_clear() + if subscription in self._retained_topics: + del self._retained_topics[subscription] + # Only unsubscribe if currently connected + if self.connected: + self._async_unsubscribe(subscription.topic) @callback def _async_unsubscribe(self, topic: str) -> None: @@ -1257,9 +1274,7 @@ class MQTT: last_discovery = self._mqtt_data.last_discovery last_subscribe = now if self._pending_subscriptions else self._last_subscribe - wait_until = max( - last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN - ) + wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN while now < wait_until: await asyncio.sleep(wait_until - now) now = time.monotonic() @@ -1267,9 +1282,7 @@ class MQTT: last_subscribe = ( now if self._pending_subscriptions else self._last_subscribe ) - wait_until = max( - last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN - ) + wait_until = max(last_discovery, last_subscribe) + DISCOVERY_COOLDOWN def _matcher_for_topic(subscription: str) -> Callable[[str], bool]: diff --git a/homeassistant/components/mqtt/climate.py b/homeassistant/components/mqtt/climate.py index b09ee17af68..57f71008ecc 100644 --- a/homeassistant/components/mqtt/climate.py +++ b/homeassistant/components/mqtt/climate.py @@ -511,7 +511,7 @@ class MqttTemperatureControlEntity(MqttEntity, ABC): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def _publish(self, topic: str, payload: PublishPayloadType) -> None: if self._topic[topic] is not None: diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index d741f602670..a4c7c1d8b3b 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -512,7 +512,7 @@ class MqttCover(MqttEntity, CoverEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_open_cover(self, **kwargs: Any) -> None: """Move the cover up. diff --git a/homeassistant/components/mqtt/device_tracker.py b/homeassistant/components/mqtt/device_tracker.py index 519af19ac16..87abba2ac95 100644 --- a/homeassistant/components/mqtt/device_tracker.py +++ b/homeassistant/components/mqtt/device_tracker.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging from typing import TYPE_CHECKING @@ -32,13 +33,7 @@ from homeassistant.helpers.typing import ConfigType from . import subscription from .config import MQTT_BASE_SCHEMA from .const import CONF_PAYLOAD_RESET, CONF_QOS, CONF_STATE_TOPIC -from .debug_info import log_messages -from .mixins import ( - CONF_JSON_ATTRS_TOPIC, - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import CONF_JSON_ATTRS_TOPIC, MqttEntity, async_setup_entity_entry_helper from .models import MqttValueTemplate, ReceiveMessage, ReceivePayloadType from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .util import valid_subscribe_topic @@ -119,33 +114,31 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value + @callback + def _tracker_message_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + payload = self._value_template(msg.payload) + if not payload.strip(): # No output from template, ignore + _LOGGER.debug( + "Ignoring empty payload '%s' after rendering for topic %s", + payload, + msg.topic, + ) + return + if payload == self._config[CONF_PAYLOAD_HOME]: + self._location_name = STATE_HOME + elif payload == self._config[CONF_PAYLOAD_NOT_HOME]: + self._location_name = STATE_NOT_HOME + elif payload == self._config[CONF_PAYLOAD_RESET]: + self._location_name = None + else: + if TYPE_CHECKING: + assert isinstance(msg.payload, str) + self._location_name = msg.payload + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_location_name"}) - def message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - payload = self._value_template(msg.payload) - if not payload.strip(): # No output from template, ignore - _LOGGER.debug( - "Ignoring empty payload '%s' after rendering for topic %s", - payload, - msg.topic, - ) - return - if payload == self._config[CONF_PAYLOAD_HOME]: - self._location_name = STATE_HOME - elif payload == self._config[CONF_PAYLOAD_NOT_HOME]: - self._location_name = STATE_NOT_HOME - elif payload == self._config[CONF_PAYLOAD_RESET]: - self._location_name = None - else: - if TYPE_CHECKING: - assert isinstance(msg.payload, str) - self._location_name = msg.payload - state_topic: str | None = self._config.get(CONF_STATE_TOPIC) if state_topic is None: return @@ -155,7 +148,12 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): { "state_topic": { "topic": state_topic, - "msg_callback": message_received, + "msg_callback": partial( + self._message_callback, + self._tracker_message_received, + {"_location_name"}, + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], } }, @@ -168,7 +166,7 @@ class MqttDeviceTracker(MqttEntity, TrackerEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @property def latitude(self) -> float | None: diff --git a/homeassistant/components/mqtt/event.py b/homeassistant/components/mqtt/event.py index 5c8ae7f7be1..a09579fccef 100644 --- a/homeassistant/components/mqtt/event.py +++ b/homeassistant/components/mqtt/event.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging from typing import Any @@ -31,7 +32,6 @@ from .const import ( PAYLOAD_EMPTY_JSON, PAYLOAD_NONE, ) -from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( DATA_MQTT, @@ -113,90 +113,91 @@ class MqttEvent(MqttEntity, EventEntity): self._config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value + @callback + def _event_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + if msg.retain: + _LOGGER.debug( + "Ignoring event trigger from replayed retained payload '%s' on topic %s", + msg.payload, + msg.topic, + ) + return + event_attributes: dict[str, Any] = {} + event_type: str + try: + payload = self._template(msg.payload, PayloadSentinel.DEFAULT) + except MqttValueTemplateException as exc: + _LOGGER.warning(exc) + return + if ( + not payload + or payload is PayloadSentinel.DEFAULT + or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON) + ): + _LOGGER.debug( + "Ignoring empty payload '%s' after rendering for topic %s", + payload, + msg.topic, + ) + return + try: + event_attributes = json_loads_object(payload) + event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE)) + _LOGGER.debug( + ( + "JSON event data detected after processing payload '%s' on" + " topic %s, type %s, attributes %s" + ), + payload, + msg.topic, + event_type, + event_attributes, + ) + except KeyError: + _LOGGER.warning( + ("`event_type` missing in JSON event payload, " " '%s' on topic %s"), + payload, + msg.topic, + ) + return + except JSON_DECODE_EXCEPTIONS: + _LOGGER.warning( + ( + "No valid JSON event payload detected, " + "value after processing payload" + " '%s' on topic %s" + ), + payload, + msg.topic, + ) + return + try: + self._trigger_event(event_type, event_attributes) + except ValueError: + _LOGGER.warning( + "Invalid event type %s for %s received on topic %s, payload %s", + event_type, + self.entity_id, + msg.topic, + payload, + ) + return + mqtt_data = self.hass.data[DATA_MQTT] + mqtt_data.state_write_requests.write_state_request(self) + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" topics: dict[str, dict[str, Any]] = {} - @callback - @log_messages(self.hass, self.entity_id) - def message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - if msg.retain: - _LOGGER.debug( - "Ignoring event trigger from replayed retained payload '%s' on topic %s", - msg.payload, - msg.topic, - ) - return - event_attributes: dict[str, Any] = {} - event_type: str - try: - payload = self._template(msg.payload, PayloadSentinel.DEFAULT) - except MqttValueTemplateException as exc: - _LOGGER.warning(exc) - return - if ( - not payload - or payload is PayloadSentinel.DEFAULT - or payload in (PAYLOAD_NONE, PAYLOAD_EMPTY_JSON) - ): - _LOGGER.debug( - "Ignoring empty payload '%s' after rendering for topic %s", - payload, - msg.topic, - ) - return - try: - event_attributes = json_loads_object(payload) - event_type = str(event_attributes.pop(event.ATTR_EVENT_TYPE)) - _LOGGER.debug( - ( - "JSON event data detected after processing payload '%s' on" - " topic %s, type %s, attributes %s" - ), - payload, - msg.topic, - event_type, - event_attributes, - ) - except KeyError: - _LOGGER.warning( - ( - "`event_type` missing in JSON event payload, " - " '%s' on topic %s" - ), - payload, - msg.topic, - ) - return - except JSON_DECODE_EXCEPTIONS: - _LOGGER.warning( - ( - "No valid JSON event payload detected, " - "value after processing payload" - " '%s' on topic %s" - ), - payload, - msg.topic, - ) - return - try: - self._trigger_event(event_type, event_attributes) - except ValueError: - _LOGGER.warning( - "Invalid event type %s for %s received on topic %s, payload %s", - event_type, - self.entity_id, - msg.topic, - payload, - ) - return - mqtt_data = self.hass.data[DATA_MQTT] - mqtt_data.state_write_requests.write_state_request(self) - topics["state_topic"] = { "topic": self._config[CONF_STATE_TOPIC], - "msg_callback": message_received, + "msg_callback": partial( + self._message_callback, + self._event_received, + None, + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } @@ -207,4 +208,4 @@ class MqttEvent(MqttEntity, EventEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index 10571043fb8..a418131d5c5 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging import math from typing import Any @@ -49,12 +50,7 @@ from .const import ( CONF_STATE_VALUE_TEMPLATE, PAYLOAD_NONE, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( MessageCallbackType, MqttCommandTemplate, @@ -338,137 +334,142 @@ class MqttFan(MqttEntity, FanEntity): for key, tpl in value_templates.items() } + @callback + def _state_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message.""" + payload = self._value_templates[CONF_STATE](msg.payload) + if not payload: + _LOGGER.debug("Ignoring empty state from '%s'", msg.topic) + return + if payload == self._payload["STATE_ON"]: + self._attr_is_on = True + elif payload == self._payload["STATE_OFF"]: + self._attr_is_on = False + elif payload == PAYLOAD_NONE: + self._attr_is_on = None + + @callback + def _percentage_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message for the percentage.""" + rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE]( + msg.payload + ) + if not rendered_percentage_payload: + _LOGGER.debug("Ignoring empty speed from '%s'", msg.topic) + return + if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]: + self._attr_percentage = None + return + try: + percentage = ranged_value_to_percentage( + self._speed_range, int(rendered_percentage_payload) + ) + except ValueError: + _LOGGER.warning( + ( + "'%s' received on topic %s. '%s' is not a valid speed within" + " the speed range" + ), + msg.payload, + msg.topic, + rendered_percentage_payload, + ) + return + if percentage < 0 or percentage > 100: + _LOGGER.warning( + ( + "'%s' received on topic %s. '%s' is not a valid speed within" + " the speed range" + ), + msg.payload, + msg.topic, + rendered_percentage_payload, + ) + return + self._attr_percentage = percentage + + @callback + def _preset_mode_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message for preset mode.""" + preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload)) + if preset_mode == self._payload["PRESET_MODE_RESET"]: + self._attr_preset_mode = None + return + if not preset_mode: + _LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic) + return + if not self.preset_modes or preset_mode not in self.preset_modes: + _LOGGER.warning( + "'%s' received on topic %s. '%s' is not a valid preset mode", + msg.payload, + msg.topic, + preset_mode, + ) + return + + self._attr_preset_mode = preset_mode + + @callback + def _oscillation_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message for the oscillation.""" + payload = self._value_templates[ATTR_OSCILLATING](msg.payload) + if not payload: + _LOGGER.debug("Ignoring empty oscillation from '%s'", msg.topic) + return + if payload == self._payload["OSCILLATE_ON_PAYLOAD"]: + self._attr_oscillating = True + elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]: + self._attr_oscillating = False + + @callback + def _direction_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message for the direction.""" + direction = self._value_templates[ATTR_DIRECTION](msg.payload) + if not direction: + _LOGGER.debug("Ignoring empty direction from '%s'", msg.topic) + return + self._attr_current_direction = str(direction) + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" topics: dict[str, Any] = {} - def add_subscribe_topic(topic: str, msg_callback: MessageCallbackType) -> bool: + def add_subscribe_topic( + topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str] + ) -> bool: """Add a topic to subscribe to.""" if has_topic := self._topic[topic] is not None: topics[topic] = { "topic": self._topic[topic], - "msg_callback": msg_callback, + "msg_callback": partial( + self._message_callback, msg_callback, tracked_attributes + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } return has_topic - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_is_on"}) - def state_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message.""" - payload = self._value_templates[CONF_STATE](msg.payload) - if not payload: - _LOGGER.debug("Ignoring empty state from '%s'", msg.topic) - return - if payload == self._payload["STATE_ON"]: - self._attr_is_on = True - elif payload == self._payload["STATE_OFF"]: - self._attr_is_on = False - elif payload == PAYLOAD_NONE: - self._attr_is_on = None - - add_subscribe_topic(CONF_STATE_TOPIC, state_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_percentage"}) - def percentage_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message for the percentage.""" - rendered_percentage_payload = self._value_templates[ATTR_PERCENTAGE]( - msg.payload - ) - if not rendered_percentage_payload: - _LOGGER.debug("Ignoring empty speed from '%s'", msg.topic) - return - if rendered_percentage_payload == self._payload["PERCENTAGE_RESET"]: - self._attr_percentage = None - return - try: - percentage = ranged_value_to_percentage( - self._speed_range, int(rendered_percentage_payload) - ) - except ValueError: - _LOGGER.warning( - ( - "'%s' received on topic %s. '%s' is not a valid speed within" - " the speed range" - ), - msg.payload, - msg.topic, - rendered_percentage_payload, - ) - return - if percentage < 0 or percentage > 100: - _LOGGER.warning( - ( - "'%s' received on topic %s. '%s' is not a valid speed within" - " the speed range" - ), - msg.payload, - msg.topic, - rendered_percentage_payload, - ) - return - self._attr_percentage = percentage - - add_subscribe_topic(CONF_PERCENTAGE_STATE_TOPIC, percentage_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_preset_mode"}) - def preset_mode_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message for preset mode.""" - preset_mode = str(self._value_templates[ATTR_PRESET_MODE](msg.payload)) - if preset_mode == self._payload["PRESET_MODE_RESET"]: - self._attr_preset_mode = None - return - if not preset_mode: - _LOGGER.debug("Ignoring empty preset_mode from '%s'", msg.topic) - return - if not self.preset_modes or preset_mode not in self.preset_modes: - _LOGGER.warning( - "'%s' received on topic %s. '%s' is not a valid preset mode", - msg.payload, - msg.topic, - preset_mode, - ) - return - - self._attr_preset_mode = preset_mode - - add_subscribe_topic(CONF_PRESET_MODE_STATE_TOPIC, preset_mode_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_oscillating"}) - def oscillation_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message for the oscillation.""" - payload = self._value_templates[ATTR_OSCILLATING](msg.payload) - if not payload: - _LOGGER.debug("Ignoring empty oscillation from '%s'", msg.topic) - return - if payload == self._payload["OSCILLATE_ON_PAYLOAD"]: - self._attr_oscillating = True - elif payload == self._payload["OSCILLATE_OFF_PAYLOAD"]: - self._attr_oscillating = False - - if add_subscribe_topic(CONF_OSCILLATION_STATE_TOPIC, oscillation_received): + add_subscribe_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"}) + add_subscribe_topic( + CONF_PERCENTAGE_STATE_TOPIC, self._percentage_received, {"_attr_percentage"} + ) + add_subscribe_topic( + CONF_PRESET_MODE_STATE_TOPIC, + self._preset_mode_received, + {"_attr_preset_mode"}, + ) + if add_subscribe_topic( + CONF_OSCILLATION_STATE_TOPIC, + self._oscillation_received, + {"_attr_oscillating"}, + ): self._attr_oscillating = False - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_current_direction"}) - def direction_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message for the direction.""" - direction = self._value_templates[ATTR_DIRECTION](msg.payload) - if not direction: - _LOGGER.debug("Ignoring empty direction from '%s'", msg.topic) - return - self._attr_current_direction = str(direction) - - add_subscribe_topic(CONF_DIRECTION_STATE_TOPIC, direction_received) + add_subscribe_topic( + CONF_DIRECTION_STATE_TOPIC, + self._direction_received, + {"_attr_current_direction"}, + ) self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics @@ -476,7 +477,7 @@ class MqttFan(MqttEntity, FanEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @property def is_on(self) -> bool | None: diff --git a/homeassistant/components/mqtt/humidifier.py b/homeassistant/components/mqtt/humidifier.py index b9f57dfe0ef..097018f008f 100644 --- a/homeassistant/components/mqtt/humidifier.py +++ b/homeassistant/components/mqtt/humidifier.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging from typing import Any @@ -51,12 +52,7 @@ from .const import ( CONF_STATE_VALUE_TEMPLATE, PAYLOAD_NONE, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( MqttCommandTemplate, MqttValueTemplate, @@ -284,164 +280,166 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): topics: dict[str, dict[str, Any]], topic: str, msg_callback: Callable[[ReceiveMessage], None], + tracked_attributes: set[str], ) -> None: """Add a subscription.""" qos: int = self._config[CONF_QOS] if topic in self._topic and self._topic[topic] is not None: topics[topic] = { "topic": self._topic[topic], - "msg_callback": msg_callback, + "msg_callback": partial( + self._message_callback, msg_callback, tracked_attributes + ), + "entity_id": self.entity_id, "qos": qos, "encoding": self._config[CONF_ENCODING] or None, } + @callback + def _state_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message.""" + payload = self._value_templates[CONF_STATE](msg.payload) + if not payload: + _LOGGER.debug("Ignoring empty state from '%s'", msg.topic) + return + if payload == self._payload["STATE_ON"]: + self._attr_is_on = True + elif payload == self._payload["STATE_OFF"]: + self._attr_is_on = False + elif payload == PAYLOAD_NONE: + self._attr_is_on = None + + @callback + def _action_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message.""" + action_payload = self._value_templates[ATTR_ACTION](msg.payload) + if not action_payload or action_payload == PAYLOAD_NONE: + _LOGGER.debug("Ignoring empty action from '%s'", msg.topic) + return + try: + self._attr_action = HumidifierAction(str(action_payload)) + except ValueError: + _LOGGER.error( + "'%s' received on topic %s. '%s' is not a valid action", + msg.payload, + msg.topic, + action_payload, + ) + return + + @callback + def _current_humidity_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message for the current humidity.""" + rendered_current_humidity_payload = self._value_templates[ + ATTR_CURRENT_HUMIDITY + ](msg.payload) + if rendered_current_humidity_payload == self._payload["HUMIDITY_RESET"]: + self._attr_current_humidity = None + return + if not rendered_current_humidity_payload: + _LOGGER.debug("Ignoring empty current humidity from '%s'", msg.topic) + return + try: + current_humidity = round(float(rendered_current_humidity_payload)) + except ValueError: + _LOGGER.warning( + "'%s' received on topic %s. '%s' is not a valid humidity", + msg.payload, + msg.topic, + rendered_current_humidity_payload, + ) + return + if current_humidity < 0 or current_humidity > 100: + _LOGGER.warning( + "'%s' received on topic %s. '%s' is not a valid humidity", + msg.payload, + msg.topic, + rendered_current_humidity_payload, + ) + return + self._attr_current_humidity = current_humidity + + @callback + def _target_humidity_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message for the target humidity.""" + rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY]( + msg.payload + ) + if not rendered_target_humidity_payload: + _LOGGER.debug("Ignoring empty target humidity from '%s'", msg.topic) + return + if rendered_target_humidity_payload == self._payload["HUMIDITY_RESET"]: + self._attr_target_humidity = None + return + try: + target_humidity = round(float(rendered_target_humidity_payload)) + except ValueError: + _LOGGER.warning( + "'%s' received on topic %s. '%s' is not a valid target humidity", + msg.payload, + msg.topic, + rendered_target_humidity_payload, + ) + return + if ( + target_humidity < self._attr_min_humidity + or target_humidity > self._attr_max_humidity + ): + _LOGGER.warning( + "'%s' received on topic %s. '%s' is not a valid target humidity", + msg.payload, + msg.topic, + rendered_target_humidity_payload, + ) + return + self._attr_target_humidity = target_humidity + + @callback + def _mode_received(self, msg: ReceiveMessage) -> None: + """Handle new received MQTT message for mode.""" + mode = str(self._value_templates[ATTR_MODE](msg.payload)) + if mode == self._payload["MODE_RESET"]: + self._attr_mode = None + return + if not mode: + _LOGGER.debug("Ignoring empty mode from '%s'", msg.topic) + return + if not self.available_modes or mode not in self.available_modes: + _LOGGER.warning( + "'%s' received on topic %s. '%s' is not a valid mode", + msg.payload, + msg.topic, + mode, + ) + return + + self._attr_mode = mode + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" topics: dict[str, Any] = {} - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_is_on"}) - def state_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message.""" - payload = self._value_templates[CONF_STATE](msg.payload) - if not payload: - _LOGGER.debug("Ignoring empty state from '%s'", msg.topic) - return - if payload == self._payload["STATE_ON"]: - self._attr_is_on = True - elif payload == self._payload["STATE_OFF"]: - self._attr_is_on = False - elif payload == PAYLOAD_NONE: - self._attr_is_on = None - - self.add_subscription(topics, CONF_STATE_TOPIC, state_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_action"}) - def action_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message.""" - action_payload = self._value_templates[ATTR_ACTION](msg.payload) - if not action_payload or action_payload == PAYLOAD_NONE: - _LOGGER.debug("Ignoring empty action from '%s'", msg.topic) - return - try: - self._attr_action = HumidifierAction(str(action_payload)) - except ValueError: - _LOGGER.error( - "'%s' received on topic %s. '%s' is not a valid action", - msg.payload, - msg.topic, - action_payload, - ) - return - - self.add_subscription(topics, CONF_ACTION_TOPIC, action_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_current_humidity"}) - def current_humidity_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message for the current humidity.""" - rendered_current_humidity_payload = self._value_templates[ - ATTR_CURRENT_HUMIDITY - ](msg.payload) - if rendered_current_humidity_payload == self._payload["HUMIDITY_RESET"]: - self._attr_current_humidity = None - return - if not rendered_current_humidity_payload: - _LOGGER.debug("Ignoring empty current humidity from '%s'", msg.topic) - return - try: - current_humidity = round(float(rendered_current_humidity_payload)) - except ValueError: - _LOGGER.warning( - "'%s' received on topic %s. '%s' is not a valid humidity", - msg.payload, - msg.topic, - rendered_current_humidity_payload, - ) - return - if current_humidity < 0 or current_humidity > 100: - _LOGGER.warning( - "'%s' received on topic %s. '%s' is not a valid humidity", - msg.payload, - msg.topic, - rendered_current_humidity_payload, - ) - return - self._attr_current_humidity = current_humidity - self.add_subscription( - topics, CONF_CURRENT_HUMIDITY_TOPIC, current_humidity_received + topics, CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"} ) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_target_humidity"}) - def target_humidity_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message for the target humidity.""" - rendered_target_humidity_payload = self._value_templates[ATTR_HUMIDITY]( - msg.payload - ) - if not rendered_target_humidity_payload: - _LOGGER.debug("Ignoring empty target humidity from '%s'", msg.topic) - return - if rendered_target_humidity_payload == self._payload["HUMIDITY_RESET"]: - self._attr_target_humidity = None - return - try: - target_humidity = round(float(rendered_target_humidity_payload)) - except ValueError: - _LOGGER.warning( - "'%s' received on topic %s. '%s' is not a valid target humidity", - msg.payload, - msg.topic, - rendered_target_humidity_payload, - ) - return - if ( - target_humidity < self._attr_min_humidity - or target_humidity > self._attr_max_humidity - ): - _LOGGER.warning( - "'%s' received on topic %s. '%s' is not a valid target humidity", - msg.payload, - msg.topic, - rendered_target_humidity_payload, - ) - return - self._attr_target_humidity = target_humidity - self.add_subscription( - topics, CONF_TARGET_HUMIDITY_STATE_TOPIC, target_humidity_received + topics, CONF_ACTION_TOPIC, self._action_received, {"_attr_action"} + ) + self.add_subscription( + topics, + CONF_CURRENT_HUMIDITY_TOPIC, + self._current_humidity_received, + {"_attr_current_humidity"}, + ) + self.add_subscription( + topics, + CONF_TARGET_HUMIDITY_STATE_TOPIC, + self._target_humidity_received, + {"_attr_target_humidity"}, + ) + self.add_subscription( + topics, CONF_MODE_STATE_TOPIC, self._mode_received, {"_attr_mode"} ) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_mode"}) - def mode_received(msg: ReceiveMessage) -> None: - """Handle new received MQTT message for mode.""" - mode = str(self._value_templates[ATTR_MODE](msg.payload)) - if mode == self._payload["MODE_RESET"]: - self._attr_mode = None - return - if not mode: - _LOGGER.debug("Ignoring empty mode from '%s'", msg.topic) - return - if not self.available_modes or mode not in self.available_modes: - _LOGGER.warning( - "'%s' received on topic %s. '%s' is not a valid mode", - msg.payload, - msg.topic, - mode, - ) - return - - self._attr_mode = mode - - self.add_subscription(topics, CONF_MODE_STATE_TOPIC, mode_received) self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics @@ -449,7 +447,7 @@ class MqttHumidifier(MqttEntity, HumidifierEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_turn_on(self, **kwargs: Any) -> None: """Turn on the entity. diff --git a/homeassistant/components/mqtt/image.py b/homeassistant/components/mqtt/image.py index eec289aa464..4fa410c4595 100644 --- a/homeassistant/components/mqtt/image.py +++ b/homeassistant/components/mqtt/image.py @@ -5,6 +5,7 @@ from __future__ import annotations from base64 import b64decode import binascii from collections.abc import Callable +from functools import partial import logging from typing import TYPE_CHECKING, Any @@ -26,7 +27,6 @@ from homeassistant.util import dt as dt_util from . import subscription from .config import MQTT_BASE_SCHEMA from .const import CONF_ENCODING, CONF_QOS -from .debug_info import log_messages from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( DATA_MQTT, @@ -143,6 +143,45 @@ class MqttImage(MqttEntity, ImageEntity): config.get(CONF_URL_TEMPLATE), entity=self ).async_render_with_possible_json_value + @callback + def _image_data_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + try: + if CONF_IMAGE_ENCODING in self._config: + self._last_image = b64decode(msg.payload) + else: + if TYPE_CHECKING: + assert isinstance(msg.payload, bytes) + self._last_image = msg.payload + except (binascii.Error, ValueError, AssertionError) as err: + _LOGGER.error( + "Error processing image data received at topic %s: %s", + msg.topic, + err, + ) + self._last_image = None + self._attr_image_last_updated = dt_util.utcnow() + self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self) + + @callback + def _image_from_url_request_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + try: + url = cv.url(self._url_template(msg.payload)) + self._attr_image_url = url + except MqttValueTemplateException as exc: + _LOGGER.warning(exc) + return + except vol.Invalid: + _LOGGER.error( + "Invalid image URL '%s' received at topic %s", + msg.payload, + msg.topic, + ) + self._attr_image_last_updated = dt_util.utcnow() + self._cached_image = None + self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self) + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @@ -159,56 +198,15 @@ class MqttImage(MqttEntity, ImageEntity): if has_topic := self._topic[topic] is not None: topics[topic] = { "topic": self._topic[topic], - "msg_callback": msg_callback, + "msg_callback": partial(self._message_callback, msg_callback, None), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": encoding, } return has_topic - @callback - @log_messages(self.hass, self.entity_id) - def image_data_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - try: - if CONF_IMAGE_ENCODING in self._config: - self._last_image = b64decode(msg.payload) - else: - if TYPE_CHECKING: - assert isinstance(msg.payload, bytes) - self._last_image = msg.payload - except (binascii.Error, ValueError, AssertionError) as err: - _LOGGER.error( - "Error processing image data received at topic %s: %s", - msg.topic, - err, - ) - self._last_image = None - self._attr_image_last_updated = dt_util.utcnow() - self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self) - - add_subscribe_topic(CONF_IMAGE_TOPIC, image_data_received) - - @callback - @log_messages(self.hass, self.entity_id) - def image_from_url_request_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - try: - url = cv.url(self._url_template(msg.payload)) - self._attr_image_url = url - except MqttValueTemplateException as exc: - _LOGGER.warning(exc) - return - except vol.Invalid: - _LOGGER.error( - "Invalid image URL '%s' received at topic %s", - msg.payload, - msg.topic, - ) - self._attr_image_last_updated = dt_util.utcnow() - self._cached_image = None - self.hass.data[DATA_MQTT].state_write_requests.write_state_request(self) - - add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received) + add_subscribe_topic(CONF_IMAGE_TOPIC, self._image_data_received) + add_subscribe_topic(CONF_URL_TOPIC, self._image_from_url_request_received) self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics @@ -216,7 +214,7 @@ class MqttImage(MqttEntity, ImageEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_image(self) -> bytes | None: """Return bytes of image.""" diff --git a/homeassistant/components/mqtt/lawn_mower.py b/homeassistant/components/mqtt/lawn_mower.py index 7380f478e2c..2452b511144 100644 --- a/homeassistant/components/mqtt/lawn_mower.py +++ b/homeassistant/components/mqtt/lawn_mower.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections.abc import Callable import contextlib +from functools import partial import logging import voluptuous as vol @@ -31,12 +32,7 @@ from .const import ( DEFAULT_OPTIMISTIC, DEFAULT_RETAIN, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( MqttCommandTemplate, MqttValueTemplate, @@ -150,57 +146,59 @@ class MqttLawnMower(MqttEntity, LawnMowerEntity, RestoreEntity): config.get(CONF_START_MOWING_COMMAND_TEMPLATE), entity=self ).async_render + @callback + def _message_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + payload = str(self._value_template(msg.payload)) + if not payload: + _LOGGER.debug( + "Invalid empty activity payload from topic %s, for entity %s", + msg.topic, + self.entity_id, + ) + return + if payload.lower() == "none": + self._attr_activity = None + return + + try: + self._attr_activity = LawnMowerActivity(payload) + except ValueError: + _LOGGER.error( + "Invalid activity for %s: '%s' (valid activities: %s)", + self.entity_id, + payload, + [option.value for option in LawnMowerActivity], + ) + return + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_activity"}) - def message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - payload = str(self._value_template(msg.payload)) - if not payload: - _LOGGER.debug( - "Invalid empty activity payload from topic %s, for entity %s", - msg.topic, - self.entity_id, - ) - return - if payload.lower() == "none": - self._attr_activity = None - return - - try: - self._attr_activity = LawnMowerActivity(payload) - except ValueError: - _LOGGER.error( - "Invalid activity for %s: '%s' (valid activities: %s)", - self.entity_id, - payload, - [option.value for option in LawnMowerActivity], - ) - return - if self._config.get(CONF_ACTIVITY_STATE_TOPIC) is None: # Force into optimistic mode. self._attr_assumed_state = True - else: - self._sub_state = subscription.async_prepare_subscribe_topics( - self.hass, - self._sub_state, - { - CONF_ACTIVITY_STATE_TOPIC: { - "topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC), - "msg_callback": message_received, - "qos": self._config[CONF_QOS], - "encoding": self._config[CONF_ENCODING] or None, - } - }, - ) + return + self._sub_state = subscription.async_prepare_subscribe_topics( + self.hass, + self._sub_state, + { + CONF_ACTIVITY_STATE_TOPIC: { + "topic": self._config.get(CONF_ACTIVITY_STATE_TOPIC), + "msg_callback": partial( + self._message_callback, + self._message_received, + {"_attr_activity"}, + ), + "entity_id": self.entity_id, + "qos": self._config[CONF_QOS], + "encoding": self._config[CONF_ENCODING] or None, + } + }, + ) async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) if self._attr_assumed_state and ( last_state := await self.async_get_last_state() diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py index 904e45b3d2f..583374c8d20 100644 --- a/homeassistant/components/mqtt/light/schema_basic.py +++ b/homeassistant/components/mqtt/light/schema_basic.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging from typing import Any, cast @@ -53,8 +54,7 @@ from ..const import ( CONF_STATE_VALUE_TEMPLATE, PAYLOAD_NONE, ) -from ..debug_info import log_messages -from ..mixins import MqttEntity, write_state_on_attr_change +from ..mixins import MqttEntity from ..models import ( MessageCallbackType, MqttCommandTemplate, @@ -378,263 +378,248 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): attr: bool = getattr(self, f"_optimistic_{attribute}") return attr + @callback + def _state_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE]( + msg.payload, PayloadSentinel.NONE + ) + if not payload: + _LOGGER.debug("Ignoring empty state message from '%s'", msg.topic) + return + + if payload == self._payload["on"]: + self._attr_is_on = True + elif payload == self._payload["off"]: + self._attr_is_on = False + elif payload == PAYLOAD_NONE: + self._attr_is_on = None + + @callback + def _brightness_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for the brightness.""" + payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE]( + msg.payload, PayloadSentinel.DEFAULT + ) + if payload is PayloadSentinel.DEFAULT or not payload: + _LOGGER.debug("Ignoring empty brightness message from '%s'", msg.topic) + return + + device_value = float(payload) + if device_value == 0: + _LOGGER.debug("Ignoring zero brightness from '%s'", msg.topic) + return + + percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE] + self._attr_brightness = min(round(percent_bright * 255), 255) + + @callback + def _rgbx_received( + self, + msg: ReceiveMessage, + template: str, + color_mode: ColorMode, + convert_color: Callable[..., tuple[int, ...]], + ) -> tuple[int, ...] | None: + """Process MQTT messages for RGBW and RGBWW.""" + payload = self._value_templates[template](msg.payload, PayloadSentinel.DEFAULT) + if payload is PayloadSentinel.DEFAULT or not payload: + _LOGGER.debug("Ignoring empty %s message from '%s'", color_mode, msg.topic) + return None + color = tuple(int(val) for val in str(payload).split(",")) + if self._optimistic_color_mode: + self._attr_color_mode = color_mode + if self._topic[CONF_BRIGHTNESS_STATE_TOPIC] is None: + rgb = convert_color(*color) + brightness = max(rgb) + if brightness == 0: + _LOGGER.debug( + "Ignoring %s message with zero rgb brightness from '%s'", + color_mode, + msg.topic, + ) + return None + self._attr_brightness = brightness + # Normalize the color to 100% brightness + color = tuple( + min(round(channel / brightness * 255), 255) for channel in color + ) + return color + + @callback + def _rgb_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for RGB.""" + rgb = self._rgbx_received( + msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x + ) + if rgb is None: + return + self._attr_rgb_color = cast(tuple[int, int, int], rgb) + + @callback + def _rgbw_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for RGBW.""" + rgbw = self._rgbx_received( + msg, + CONF_RGBW_VALUE_TEMPLATE, + ColorMode.RGBW, + color_util.color_rgbw_to_rgb, + ) + if rgbw is None: + return + self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw) + + @callback + def _rgbww_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for RGBWW.""" + + @callback + def _converter( + r: int, g: int, b: int, cw: int, ww: int + ) -> tuple[int, int, int]: + min_kelvin = color_util.color_temperature_mired_to_kelvin(self.max_mireds) + max_kelvin = color_util.color_temperature_mired_to_kelvin(self.min_mireds) + return color_util.color_rgbww_to_rgb( + r, g, b, cw, ww, min_kelvin, max_kelvin + ) + + rgbww = self._rgbx_received( + msg, + CONF_RGBWW_VALUE_TEMPLATE, + ColorMode.RGBWW, + _converter, + ) + if rgbww is None: + return + self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww) + + @callback + def _color_mode_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for color mode.""" + payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE]( + msg.payload, PayloadSentinel.DEFAULT + ) + if payload is PayloadSentinel.DEFAULT or not payload: + _LOGGER.debug("Ignoring empty color mode message from '%s'", msg.topic) + return + + self._attr_color_mode = ColorMode(str(payload)) + + @callback + def _color_temp_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for color temperature.""" + payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE]( + msg.payload, PayloadSentinel.DEFAULT + ) + if payload is PayloadSentinel.DEFAULT or not payload: + _LOGGER.debug("Ignoring empty color temp message from '%s'", msg.topic) + return + + if self._optimistic_color_mode: + self._attr_color_mode = ColorMode.COLOR_TEMP + self._attr_color_temp = int(payload) + + @callback + def _effect_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for effect.""" + payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE]( + msg.payload, PayloadSentinel.DEFAULT + ) + if payload is PayloadSentinel.DEFAULT or not payload: + _LOGGER.debug("Ignoring empty effect message from '%s'", msg.topic) + return + + self._attr_effect = str(payload) + + @callback + def _hs_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for hs color.""" + payload = self._value_templates[CONF_HS_VALUE_TEMPLATE]( + msg.payload, PayloadSentinel.DEFAULT + ) + if payload is PayloadSentinel.DEFAULT or not payload: + _LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic) + return + try: + hs_color = tuple(float(val) for val in str(payload).split(",", 2)) + if self._optimistic_color_mode: + self._attr_color_mode = ColorMode.HS + self._attr_hs_color = cast(tuple[float, float], hs_color) + except ValueError: + _LOGGER.warning("Failed to parse hs state update: '%s'", payload) + + @callback + def _xy_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages for xy color.""" + payload = self._value_templates[CONF_XY_VALUE_TEMPLATE]( + msg.payload, PayloadSentinel.DEFAULT + ) + if payload is PayloadSentinel.DEFAULT or not payload: + _LOGGER.debug("Ignoring empty xy-color message from '%s'", msg.topic) + return + + xy_color = tuple(float(val) for val in str(payload).split(",", 2)) + if self._optimistic_color_mode: + self._attr_color_mode = ColorMode.XY + self._attr_xy_color = cast(tuple[float, float], xy_color) + def _prepare_subscribe_topics(self) -> None: # noqa: C901 """(Re)Subscribe to topics.""" topics: dict[str, dict[str, Any]] = {} - def add_topic(topic: str, msg_callback: MessageCallbackType) -> None: + def add_topic( + topic: str, msg_callback: MessageCallbackType, tracked_attributes: set[str] + ) -> None: """Add a topic.""" if self._topic[topic] is not None: topics[topic] = { "topic": self._topic[topic], - "msg_callback": msg_callback, + "msg_callback": partial( + self._message_callback, msg_callback, tracked_attributes + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_is_on"}) - def state_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - payload = self._value_templates[CONF_STATE_VALUE_TEMPLATE]( - msg.payload, PayloadSentinel.NONE - ) - if not payload: - _LOGGER.debug("Ignoring empty state message from '%s'", msg.topic) - return - - if payload == self._payload["on"]: - self._attr_is_on = True - elif payload == self._payload["off"]: - self._attr_is_on = False - elif payload == PAYLOAD_NONE: - self._attr_is_on = None - - if self._topic[CONF_STATE_TOPIC] is not None: - topics[CONF_STATE_TOPIC] = { - "topic": self._topic[CONF_STATE_TOPIC], - "msg_callback": state_received, - "qos": self._config[CONF_QOS], - "encoding": self._config[CONF_ENCODING] or None, - } - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_brightness"}) - def brightness_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for the brightness.""" - payload = self._value_templates[CONF_BRIGHTNESS_VALUE_TEMPLATE]( - msg.payload, PayloadSentinel.DEFAULT - ) - if payload is PayloadSentinel.DEFAULT or not payload: - _LOGGER.debug("Ignoring empty brightness message from '%s'", msg.topic) - return - - device_value = float(payload) - if device_value == 0: - _LOGGER.debug("Ignoring zero brightness from '%s'", msg.topic) - return - - percent_bright = device_value / self._config[CONF_BRIGHTNESS_SCALE] - self._attr_brightness = min(round(percent_bright * 255), 255) - - add_topic(CONF_BRIGHTNESS_STATE_TOPIC, brightness_received) - - @callback - def _rgbx_received( - msg: ReceiveMessage, - template: str, - color_mode: ColorMode, - convert_color: Callable[..., tuple[int, ...]], - ) -> tuple[int, ...] | None: - """Handle new MQTT messages for RGBW and RGBWW.""" - payload = self._value_templates[template]( - msg.payload, PayloadSentinel.DEFAULT - ) - if payload is PayloadSentinel.DEFAULT or not payload: - _LOGGER.debug( - "Ignoring empty %s message from '%s'", color_mode, msg.topic - ) - return None - color = tuple(int(val) for val in str(payload).split(",")) - if self._optimistic_color_mode: - self._attr_color_mode = color_mode - if self._topic[CONF_BRIGHTNESS_STATE_TOPIC] is None: - rgb = convert_color(*color) - brightness = max(rgb) - if brightness == 0: - _LOGGER.debug( - "Ignoring %s message with zero rgb brightness from '%s'", - color_mode, - msg.topic, - ) - return None - self._attr_brightness = brightness - # Normalize the color to 100% brightness - color = tuple( - min(round(channel / brightness * 255), 255) for channel in color - ) - return color - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"} + add_topic(CONF_STATE_TOPIC, self._state_received, {"_attr_is_on"}) + add_topic( + CONF_BRIGHTNESS_STATE_TOPIC, self._brightness_received, {"_attr_brightness"} ) - def rgb_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for RGB.""" - rgb = _rgbx_received( - msg, CONF_RGB_VALUE_TEMPLATE, ColorMode.RGB, lambda *x: x - ) - if rgb is None: - return - self._attr_rgb_color = cast(tuple[int, int, int], rgb) - - add_topic(CONF_RGB_STATE_TOPIC, rgb_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"} + add_topic( + CONF_RGB_STATE_TOPIC, + self._rgb_received, + {"_attr_brightness", "_attr_color_mode", "_attr_rgb_color"}, ) - def rgbw_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for RGBW.""" - rgbw = _rgbx_received( - msg, - CONF_RGBW_VALUE_TEMPLATE, - ColorMode.RGBW, - color_util.color_rgbw_to_rgb, - ) - if rgbw is None: - return - self._attr_rgbw_color = cast(tuple[int, int, int, int], rgbw) - - add_topic(CONF_RGBW_STATE_TOPIC, rgbw_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"} + add_topic( + CONF_RGBW_STATE_TOPIC, + self._rgbw_received, + {"_attr_brightness", "_attr_color_mode", "_attr_rgbw_color"}, + ) + add_topic( + CONF_RGBWW_STATE_TOPIC, + self._rgbww_received, + {"_attr_brightness", "_attr_color_mode", "_attr_rgbww_color"}, + ) + add_topic( + CONF_COLOR_MODE_STATE_TOPIC, self._color_mode_received, {"_attr_color_mode"} + ) + add_topic( + CONF_COLOR_TEMP_STATE_TOPIC, + self._color_temp_received, + {"_attr_color_mode", "_attr_color_temp"}, + ) + add_topic(CONF_EFFECT_STATE_TOPIC, self._effect_received, {"_attr_effect"}) + add_topic( + CONF_HS_STATE_TOPIC, + self._hs_received, + {"_attr_color_mode", "_attr_hs_color"}, + ) + add_topic( + CONF_XY_STATE_TOPIC, + self._xy_received, + {"_attr_color_mode", "_attr_xy_color"}, ) - def rgbww_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for RGBWW.""" - - @callback - def _converter( - r: int, g: int, b: int, cw: int, ww: int - ) -> tuple[int, int, int]: - min_kelvin = color_util.color_temperature_mired_to_kelvin( - self.max_mireds - ) - max_kelvin = color_util.color_temperature_mired_to_kelvin( - self.min_mireds - ) - return color_util.color_rgbww_to_rgb( - r, g, b, cw, ww, min_kelvin, max_kelvin - ) - - rgbww = _rgbx_received( - msg, - CONF_RGBWW_VALUE_TEMPLATE, - ColorMode.RGBWW, - _converter, - ) - if rgbww is None: - return - self._attr_rgbww_color = cast(tuple[int, int, int, int, int], rgbww) - - add_topic(CONF_RGBWW_STATE_TOPIC, rgbww_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_color_mode"}) - def color_mode_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for color mode.""" - payload = self._value_templates[CONF_COLOR_MODE_VALUE_TEMPLATE]( - msg.payload, PayloadSentinel.DEFAULT - ) - if payload is PayloadSentinel.DEFAULT or not payload: - _LOGGER.debug("Ignoring empty color mode message from '%s'", msg.topic) - return - - self._attr_color_mode = ColorMode(str(payload)) - - add_topic(CONF_COLOR_MODE_STATE_TOPIC, color_mode_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_color_mode", "_attr_color_temp"}) - def color_temp_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for color temperature.""" - payload = self._value_templates[CONF_COLOR_TEMP_VALUE_TEMPLATE]( - msg.payload, PayloadSentinel.DEFAULT - ) - if payload is PayloadSentinel.DEFAULT or not payload: - _LOGGER.debug("Ignoring empty color temp message from '%s'", msg.topic) - return - - if self._optimistic_color_mode: - self._attr_color_mode = ColorMode.COLOR_TEMP - self._attr_color_temp = int(payload) - - add_topic(CONF_COLOR_TEMP_STATE_TOPIC, color_temp_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_effect"}) - def effect_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for effect.""" - payload = self._value_templates[CONF_EFFECT_VALUE_TEMPLATE]( - msg.payload, PayloadSentinel.DEFAULT - ) - if payload is PayloadSentinel.DEFAULT or not payload: - _LOGGER.debug("Ignoring empty effect message from '%s'", msg.topic) - return - - self._attr_effect = str(payload) - - add_topic(CONF_EFFECT_STATE_TOPIC, effect_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_color_mode", "_attr_hs_color"}) - def hs_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for hs color.""" - payload = self._value_templates[CONF_HS_VALUE_TEMPLATE]( - msg.payload, PayloadSentinel.DEFAULT - ) - if payload is PayloadSentinel.DEFAULT or not payload: - _LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic) - return - try: - hs_color = tuple(float(val) for val in str(payload).split(",", 2)) - if self._optimistic_color_mode: - self._attr_color_mode = ColorMode.HS - self._attr_hs_color = cast(tuple[float, float], hs_color) - except ValueError: - _LOGGER.warning("Failed to parse hs state update: '%s'", payload) - - add_topic(CONF_HS_STATE_TOPIC, hs_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_color_mode", "_attr_xy_color"}) - def xy_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages for xy color.""" - payload = self._value_templates[CONF_XY_VALUE_TEMPLATE]( - msg.payload, PayloadSentinel.DEFAULT - ) - if payload is PayloadSentinel.DEFAULT or not payload: - _LOGGER.debug("Ignoring empty xy-color message from '%s'", msg.topic) - return - - xy_color = tuple(float(val) for val in str(payload).split(",", 2)) - if self._optimistic_color_mode: - self._attr_color_mode = ColorMode.XY - self._attr_xy_color = cast(tuple[float, float], xy_color) - - add_topic(CONF_XY_STATE_TOPIC, xy_received) self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics @@ -642,7 +627,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) last_state = await self.async_get_last_state() def restore_state( diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py index 52fbf3429b6..f6dec17f8f3 100644 --- a/homeassistant/components/mqtt/light/schema_json.py +++ b/homeassistant/components/mqtt/light/schema_json.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections.abc import Callable from contextlib import suppress +from functools import partial import logging from typing import TYPE_CHECKING, Any, cast @@ -66,8 +67,7 @@ from ..const import ( CONF_STATE_TOPIC, DOMAIN as MQTT_DOMAIN, ) -from ..debug_info import log_messages -from ..mixins import MqttEntity, write_state_on_attr_change +from ..mixins import MqttEntity from ..models import ReceiveMessage from ..schemas import MQTT_ENTITY_COMMON_SCHEMA from ..util import valid_subscribe_topic @@ -414,118 +414,121 @@ class MqttLightJson(MqttEntity, LightEntity, RestoreEntity): self.entity_id, ) + @callback + def _state_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + values = json_loads_object(msg.payload) + + if values["state"] == "ON": + self._attr_is_on = True + elif values["state"] == "OFF": + self._attr_is_on = False + elif values["state"] is None: + self._attr_is_on = None + + if ( + self._deprecated_color_handling + and color_supported(self.supported_color_modes) + and "color" in values + ): + # Deprecated color handling + if values["color"] is None: + self._attr_hs_color = None + else: + self._update_color(values) + + if not self._deprecated_color_handling and "color_mode" in values: + self._update_color(values) + + if brightness_supported(self.supported_color_modes): + try: + if brightness := values["brightness"]: + if TYPE_CHECKING: + assert isinstance(brightness, float) + self._attr_brightness = color_util.value_to_brightness( + (1, self._config[CONF_BRIGHTNESS_SCALE]), brightness + ) + else: + _LOGGER.debug( + "Ignoring zero brightness value for entity %s", + self.entity_id, + ) + + except KeyError: + pass + except (TypeError, ValueError): + _LOGGER.warning( + "Invalid brightness value '%s' received for entity %s", + values["brightness"], + self.entity_id, + ) + + if ( + self._deprecated_color_handling + and self.supported_color_modes + and ColorMode.COLOR_TEMP in self.supported_color_modes + ): + # Deprecated color handling + try: + if values["color_temp"] is None: + self._attr_color_temp = None + else: + self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type] + except KeyError: + pass + except ValueError: + _LOGGER.warning( + "Invalid color temp value '%s' received for entity %s", + values["color_temp"], + self.entity_id, + ) + # Allow to switch back to color_temp + if "color" not in values: + self._attr_hs_color = None + + if self.supported_features and LightEntityFeature.EFFECT: + with suppress(KeyError): + self._attr_effect = cast(str, values["effect"]) + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, + # + if self._topic[CONF_STATE_TOPIC] is None: + return + + self._sub_state = subscription.async_prepare_subscribe_topics( + self.hass, + self._sub_state, { - "_attr_brightness", - "_attr_color_temp", - "_attr_effect", - "_attr_hs_color", - "_attr_is_on", - "_attr_rgb_color", - "_attr_rgbw_color", - "_attr_rgbww_color", - "_attr_xy_color", - "color_mode", + CONF_STATE_TOPIC: { + "topic": self._topic[CONF_STATE_TOPIC], + "msg_callback": partial( + self._message_callback, + self._state_received, + { + "_attr_brightness", + "_attr_color_temp", + "_attr_effect", + "_attr_hs_color", + "_attr_is_on", + "_attr_rgb_color", + "_attr_rgbw_color", + "_attr_rgbww_color", + "_attr_xy_color", + "color_mode", + }, + ), + "entity_id": self.entity_id, + "qos": self._config[CONF_QOS], + "encoding": self._config[CONF_ENCODING] or None, + } }, ) - def state_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - values = json_loads_object(msg.payload) - - if values["state"] == "ON": - self._attr_is_on = True - elif values["state"] == "OFF": - self._attr_is_on = False - elif values["state"] is None: - self._attr_is_on = None - - if ( - self._deprecated_color_handling - and color_supported(self.supported_color_modes) - and "color" in values - ): - # Deprecated color handling - if values["color"] is None: - self._attr_hs_color = None - else: - self._update_color(values) - - if not self._deprecated_color_handling and "color_mode" in values: - self._update_color(values) - - if brightness_supported(self.supported_color_modes): - try: - if brightness := values["brightness"]: - if TYPE_CHECKING: - assert isinstance(brightness, float) - self._attr_brightness = color_util.value_to_brightness( - (1, self._config[CONF_BRIGHTNESS_SCALE]), brightness - ) - else: - _LOGGER.debug( - "Ignoring zero brightness value for entity %s", - self.entity_id, - ) - - except KeyError: - pass - except (TypeError, ValueError): - _LOGGER.warning( - "Invalid brightness value '%s' received for entity %s", - values["brightness"], - self.entity_id, - ) - - if ( - self._deprecated_color_handling - and self.supported_color_modes - and ColorMode.COLOR_TEMP in self.supported_color_modes - ): - # Deprecated color handling - try: - if values["color_temp"] is None: - self._attr_color_temp = None - else: - self._attr_color_temp = int(values["color_temp"]) # type: ignore[arg-type] - except KeyError: - pass - except ValueError: - _LOGGER.warning( - "Invalid color temp value '%s' received for entity %s", - values["color_temp"], - self.entity_id, - ) - # Allow to switch back to color_temp - if "color" not in values: - self._attr_hs_color = None - - if self.supported_features and LightEntityFeature.EFFECT: - with suppress(KeyError): - self._attr_effect = cast(str, values["effect"]) - - if self._topic[CONF_STATE_TOPIC] is not None: - self._sub_state = subscription.async_prepare_subscribe_topics( - self.hass, - self._sub_state, - { - "state_topic": { - "topic": self._topic[CONF_STATE_TOPIC], - "msg_callback": state_received, - "qos": self._config[CONF_QOS], - "encoding": self._config[CONF_ENCODING] or None, - } - }, - ) async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) last_state = await self.async_get_last_state() if self._optimistic and last_state: diff --git a/homeassistant/components/mqtt/light/schema_template.py b/homeassistant/components/mqtt/light/schema_template.py index 651b691e28e..193b4d23931 100644 --- a/homeassistant/components/mqtt/light/schema_template.py +++ b/homeassistant/components/mqtt/light/schema_template.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging from typing import Any @@ -44,8 +45,7 @@ from ..const import ( CONF_STATE_TOPIC, PAYLOAD_NONE, ) -from ..debug_info import log_messages -from ..mixins import MqttEntity, write_state_on_attr_change +from ..mixins import MqttEntity from ..models import ( MqttCommandTemplate, MqttValueTemplate, @@ -188,107 +188,107 @@ class MqttLightTemplate(MqttEntity, LightEntity, RestoreEntity): # Support for ct + hs, prioritize hs self._attr_color_mode = ColorMode.HS if self.hs_color else ColorMode.COLOR_TEMP + @callback + def _state_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload) + if state == STATE_ON: + self._attr_is_on = True + elif state == STATE_OFF: + self._attr_is_on = False + elif state == PAYLOAD_NONE: + self._attr_is_on = None + else: + _LOGGER.warning("Invalid state value received") + + if CONF_BRIGHTNESS_TEMPLATE in self._config: + try: + if brightness := int( + self._value_templates[CONF_BRIGHTNESS_TEMPLATE](msg.payload) + ): + self._attr_brightness = brightness + else: + _LOGGER.debug( + "Ignoring zero brightness value for entity %s", + self.entity_id, + ) + + except ValueError: + _LOGGER.warning("Invalid brightness value received from %s", msg.topic) + + if CONF_COLOR_TEMP_TEMPLATE in self._config: + try: + color_temp = self._value_templates[CONF_COLOR_TEMP_TEMPLATE]( + msg.payload + ) + self._attr_color_temp = ( + int(color_temp) if color_temp != "None" else None + ) + except ValueError: + _LOGGER.warning("Invalid color temperature value received") + + if ( + CONF_RED_TEMPLATE in self._config + and CONF_GREEN_TEMPLATE in self._config + and CONF_BLUE_TEMPLATE in self._config + ): + try: + red = self._value_templates[CONF_RED_TEMPLATE](msg.payload) + green = self._value_templates[CONF_GREEN_TEMPLATE](msg.payload) + blue = self._value_templates[CONF_BLUE_TEMPLATE](msg.payload) + if red == "None" and green == "None" and blue == "None": + self._attr_hs_color = None + else: + self._attr_hs_color = color_util.color_RGB_to_hs( + int(red), int(green), int(blue) + ) + self._update_color_mode() + except ValueError: + _LOGGER.warning("Invalid color value received") + + if CONF_EFFECT_TEMPLATE in self._config: + effect = str(self._value_templates[CONF_EFFECT_TEMPLATE](msg.payload)) + if ( + effect_list := self._config[CONF_EFFECT_LIST] + ) and effect in effect_list: + self._attr_effect = effect + else: + _LOGGER.warning("Unsupported effect value received") + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, + if self._topics[CONF_STATE_TOPIC] is None: + return + + self._sub_state = subscription.async_prepare_subscribe_topics( + self.hass, + self._sub_state, { - "_attr_brightness", - "_attr_color_mode", - "_attr_color_temp", - "_attr_effect", - "_attr_hs_color", - "_attr_is_on", + "state_topic": { + "topic": self._topics[CONF_STATE_TOPIC], + "msg_callback": partial( + self._message_callback, + self._state_received, + { + "_attr_brightness", + "_attr_color_mode", + "_attr_color_temp", + "_attr_effect", + "_attr_hs_color", + "_attr_is_on", + }, + ), + "entity_id": self.entity_id, + "qos": self._config[CONF_QOS], + "encoding": self._config[CONF_ENCODING] or None, + } }, ) - def state_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - state = self._value_templates[CONF_STATE_TEMPLATE](msg.payload) - if state == STATE_ON: - self._attr_is_on = True - elif state == STATE_OFF: - self._attr_is_on = False - elif state == PAYLOAD_NONE: - self._attr_is_on = None - else: - _LOGGER.warning("Invalid state value received") - - if CONF_BRIGHTNESS_TEMPLATE in self._config: - try: - if brightness := int( - self._value_templates[CONF_BRIGHTNESS_TEMPLATE](msg.payload) - ): - self._attr_brightness = brightness - else: - _LOGGER.debug( - "Ignoring zero brightness value for entity %s", - self.entity_id, - ) - - except ValueError: - _LOGGER.warning( - "Invalid brightness value received from %s", msg.topic - ) - - if CONF_COLOR_TEMP_TEMPLATE in self._config: - try: - color_temp = self._value_templates[CONF_COLOR_TEMP_TEMPLATE]( - msg.payload - ) - self._attr_color_temp = ( - int(color_temp) if color_temp != "None" else None - ) - except ValueError: - _LOGGER.warning("Invalid color temperature value received") - - if ( - CONF_RED_TEMPLATE in self._config - and CONF_GREEN_TEMPLATE in self._config - and CONF_BLUE_TEMPLATE in self._config - ): - try: - red = self._value_templates[CONF_RED_TEMPLATE](msg.payload) - green = self._value_templates[CONF_GREEN_TEMPLATE](msg.payload) - blue = self._value_templates[CONF_BLUE_TEMPLATE](msg.payload) - if red == "None" and green == "None" and blue == "None": - self._attr_hs_color = None - else: - self._attr_hs_color = color_util.color_RGB_to_hs( - int(red), int(green), int(blue) - ) - self._update_color_mode() - except ValueError: - _LOGGER.warning("Invalid color value received") - - if CONF_EFFECT_TEMPLATE in self._config: - effect = str(self._value_templates[CONF_EFFECT_TEMPLATE](msg.payload)) - if ( - effect_list := self._config[CONF_EFFECT_LIST] - ) and effect in effect_list: - self._attr_effect = effect - else: - _LOGGER.warning("Unsupported effect value received") - - if self._topics[CONF_STATE_TOPIC] is not None: - self._sub_state = subscription.async_prepare_subscribe_topics( - self.hass, - self._sub_state, - { - "state_topic": { - "topic": self._topics[CONF_STATE_TOPIC], - "msg_callback": state_received, - "qos": self._config[CONF_QOS], - "encoding": self._config[CONF_ENCODING] or None, - } - }, - ) async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) last_state = await self.async_get_last_state() if self._optimistic and last_state: diff --git a/homeassistant/components/mqtt/lock.py b/homeassistant/components/mqtt/lock.py index 3dfd2b2e6d2..52c2bea2cc3 100644 --- a/homeassistant/components/mqtt/lock.py +++ b/homeassistant/components/mqtt/lock.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging import re from typing import Any @@ -36,12 +37,7 @@ from .const import ( CONF_STATE_OPENING, CONF_STATE_TOPIC, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( MqttCommandTemplate, MqttValueTemplate, @@ -186,57 +182,58 @@ class MqttLock(MqttEntity, LockEntity): self._valid_states = [config[state] for state in STATE_CONFIG_KEYS] + @callback + def _message_received(self, msg: ReceiveMessage) -> None: + """Handle new lock state messages.""" + payload = self._value_template(msg.payload) + if not payload.strip(): # No output from template, ignore + _LOGGER.debug( + "Ignoring empty payload '%s' after rendering for topic %s", + payload, + msg.topic, + ) + return + if payload == self._config[CONF_PAYLOAD_RESET]: + # Reset the state to `unknown` + self._attr_is_locked = None + elif payload in self._valid_states: + self._attr_is_locked = payload == self._config[CONF_STATE_LOCKED] + self._attr_is_locking = payload == self._config[CONF_STATE_LOCKING] + self._attr_is_open = payload == self._config[CONF_STATE_OPEN] + self._attr_is_opening = payload == self._config[CONF_STATE_OPENING] + self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING] + self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED] + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - - topics: dict[str, dict[str, Any]] = {} + topics: dict[str, dict[str, Any]] qos: int = self._config[CONF_QOS] encoding: str | None = self._config[CONF_ENCODING] or None - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, - { - "_attr_is_jammed", - "_attr_is_locked", - "_attr_is_locking", - "_attr_is_open", - "_attr_is_opening", - "_attr_is_unlocking", - }, - ) - def message_received(msg: ReceiveMessage) -> None: - """Handle new lock state messages.""" - payload = self._value_template(msg.payload) - if not payload.strip(): # No output from template, ignore - _LOGGER.debug( - "Ignoring empty payload '%s' after rendering for topic %s", - payload, - msg.topic, - ) - return - if payload == self._config[CONF_PAYLOAD_RESET]: - # Reset the state to `unknown` - self._attr_is_locked = None - elif payload in self._valid_states: - self._attr_is_locked = payload == self._config[CONF_STATE_LOCKED] - self._attr_is_locking = payload == self._config[CONF_STATE_LOCKING] - self._attr_is_open = payload == self._config[CONF_STATE_OPEN] - self._attr_is_opening = payload == self._config[CONF_STATE_OPENING] - self._attr_is_unlocking = payload == self._config[CONF_STATE_UNLOCKING] - self._attr_is_jammed = payload == self._config[CONF_STATE_JAMMED] - if self._config.get(CONF_STATE_TOPIC) is None: # Force into optimistic mode. self._optimistic = True - else: - topics[CONF_STATE_TOPIC] = { + return + topics = { + CONF_STATE_TOPIC: { "topic": self._config.get(CONF_STATE_TOPIC), - "msg_callback": message_received, + "msg_callback": partial( + self._message_callback, + self._message_received, + { + "_attr_is_jammed", + "_attr_is_locked", + "_attr_is_locking", + "_attr_is_open", + "_attr_is_opening", + "_attr_is_unlocking", + }, + ), + "entity_id": self.entity_id, CONF_QOS: qos, CONF_ENCODING: encoding, } + } self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, @@ -246,7 +243,7 @@ class MqttLock(MqttEntity, LockEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_lock(self, **kwargs: Any) -> None: """Lock the device. diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 8d294a45e97..0331b49c2a6 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -114,7 +114,7 @@ from .models import ( from .subscription import ( EntitySubscription, async_prepare_subscribe_topics, - async_subscribe_topics, + async_subscribe_topics_internal, async_unsubscribe_topics, ) from .util import mqtt_config_entry_enabled @@ -413,7 +413,7 @@ class MqttAttributesMixin(Entity): """Subscribe MQTT events.""" await super().async_added_to_hass() self._attributes_prepare_subscribe_topics() - await self._attributes_subscribe_topics() + self._attributes_subscribe_topics() def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" @@ -422,7 +422,7 @@ class MqttAttributesMixin(Entity): async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" - await self._attributes_subscribe_topics() + self._attributes_subscribe_topics() def _attributes_prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @@ -447,9 +447,10 @@ class MqttAttributesMixin(Entity): }, ) - async def _attributes_subscribe_topics(self) -> None: + @callback + def _attributes_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await async_subscribe_topics(self.hass, self._attributes_sub_state) + async_subscribe_topics_internal(self.hass, self._attributes_sub_state) async def async_will_remove_from_hass(self) -> None: """Unsubscribe when removed.""" @@ -494,7 +495,7 @@ class MqttAvailabilityMixin(Entity): """Subscribe MQTT events.""" await super().async_added_to_hass() self._availability_prepare_subscribe_topics() - await self._availability_subscribe_topics() + self._availability_subscribe_topics() self.async_on_remove( async_dispatcher_connect(self.hass, MQTT_CONNECTED, self.async_mqtt_connect) ) @@ -511,7 +512,7 @@ class MqttAvailabilityMixin(Entity): async def availability_discovery_update(self, config: DiscoveryInfoType) -> None: """Handle updated discovery message.""" - await self._availability_subscribe_topics() + self._availability_subscribe_topics() def _availability_setup_from_config(self, config: ConfigType) -> None: """(Re)Setup.""" @@ -579,9 +580,10 @@ class MqttAvailabilityMixin(Entity): self._available[topic] = False self._available_latest = False - async def _availability_subscribe_topics(self) -> None: + @callback + def _availability_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await async_subscribe_topics(self.hass, self._availability_sub_state) + async_subscribe_topics_internal(self.hass, self._availability_sub_state) @callback def async_mqtt_connect(self) -> None: @@ -1254,13 +1256,15 @@ class MqttEntity( def _message_callback( self, msg_callback: MessageCallbackType, - attributes: set[str], + attributes: set[str] | None, msg: ReceiveMessage, ) -> None: """Process the message callback.""" - attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple( - (attribute, getattr(self, attribute, UNDEFINED)) for attribute in attributes - ) + if attributes is not None: + attrs_snapshot: tuple[tuple[str, Any | UndefinedType], ...] = tuple( + (attribute, getattr(self, attribute, UNDEFINED)) + for attribute in attributes + ) mqtt_data = self.hass.data[DATA_MQTT] messages = mqtt_data.debug_info_entities[self.entity_id]["subscriptions"][ msg.subscribed_topic @@ -1274,7 +1278,7 @@ class MqttEntity( _LOGGER.warning(exc) return - if self._attrs_have_changed(attrs_snapshot): + if attributes is not None and self._attrs_have_changed(attrs_snapshot): mqtt_data.state_write_requests.write_state_request(self) diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index 18ba6718e9d..69d02fa4a24 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -5,7 +5,7 @@ from __future__ import annotations from ast import literal_eval import asyncio from collections import deque -from collections.abc import Callable, Coroutine +from collections.abc import Callable from dataclasses import dataclass, field from enum import StrEnum import logging @@ -70,7 +70,6 @@ class ReceiveMessage: timestamp: float -type AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]] type MessageCallbackType = Callable[[ReceiveMessage], None] diff --git a/homeassistant/components/mqtt/number.py b/homeassistant/components/mqtt/number.py index 74d768ae598..17e7cfe69e0 100644 --- a/homeassistant/components/mqtt/number.py +++ b/homeassistant/components/mqtt/number.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging import voluptuous as vol @@ -41,12 +42,7 @@ from .const import ( CONF_RETAIN, CONF_STATE_TOPIC, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( MqttCommandTemplate, MqttValueTemplate, @@ -165,64 +161,66 @@ class MqttNumber(MqttEntity, RestoreNumber): self._attr_native_step = config[CONF_STEP] self._attr_native_unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) + @callback + def _message_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + num_value: int | float | None + payload = str(self._value_template(msg.payload)) + if not payload.strip(): + _LOGGER.debug("Ignoring empty state update from '%s'", msg.topic) + return + try: + if payload == self._config[CONF_PAYLOAD_RESET]: + num_value = None + elif payload.isnumeric(): + num_value = int(payload) + else: + num_value = float(payload) + except ValueError: + _LOGGER.warning("Payload '%s' is not a Number", msg.payload) + return + + if num_value is not None and ( + num_value < self.min_value or num_value > self.max_value + ): + _LOGGER.error( + "Invalid value for %s: %s (range %s - %s)", + self.entity_id, + num_value, + self.min_value, + self.max_value, + ) + return + + self._attr_native_value = num_value + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_native_value"}) - def message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - num_value: int | float | None - payload = str(self._value_template(msg.payload)) - if not payload.strip(): - _LOGGER.debug("Ignoring empty state update from '%s'", msg.topic) - return - try: - if payload == self._config[CONF_PAYLOAD_RESET]: - num_value = None - elif payload.isnumeric(): - num_value = int(payload) - else: - num_value = float(payload) - except ValueError: - _LOGGER.warning("Payload '%s' is not a Number", msg.payload) - return - - if num_value is not None and ( - num_value < self.min_value or num_value > self.max_value - ): - _LOGGER.error( - "Invalid value for %s: %s (range %s - %s)", - self.entity_id, - num_value, - self.min_value, - self.max_value, - ) - return - - self._attr_native_value = num_value - if self._config.get(CONF_STATE_TOPIC) is None: # Force into optimistic mode. self._attr_assumed_state = True - else: - self._sub_state = subscription.async_prepare_subscribe_topics( - self.hass, - self._sub_state, - { - "state_topic": { - "topic": self._config.get(CONF_STATE_TOPIC), - "msg_callback": message_received, - "qos": self._config[CONF_QOS], - "encoding": self._config[CONF_ENCODING] or None, - } - }, - ) + return + self._sub_state = subscription.async_prepare_subscribe_topics( + self.hass, + self._sub_state, + { + "state_topic": { + "topic": self._config.get(CONF_STATE_TOPIC), + "msg_callback": partial( + self._message_callback, + self._message_received, + {"_attr_native_value"}, + ), + "entity_id": self.entity_id, + "qos": self._config[CONF_QOS], + "encoding": self._config[CONF_ENCODING] or None, + } + }, + ) async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) if self._attr_assumed_state and ( last_number_data := await self.async_get_last_number_data() diff --git a/homeassistant/components/mqtt/select.py b/homeassistant/components/mqtt/select.py index 05df697764d..a2814055a7c 100644 --- a/homeassistant/components/mqtt/select.py +++ b/homeassistant/components/mqtt/select.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging import voluptuous as vol @@ -27,12 +28,7 @@ from .const import ( CONF_RETAIN, CONF_STATE_TOPIC, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( MqttCommandTemplate, MqttValueTemplate, @@ -113,56 +109,58 @@ class MqttSelect(MqttEntity, SelectEntity, RestoreEntity): config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value + @callback + def _message_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT messages.""" + payload = str(self._value_template(msg.payload)) + if not payload.strip(): # No output from template, ignore + _LOGGER.debug( + "Ignoring empty payload '%s' after rendering for topic %s", + payload, + msg.topic, + ) + return + if payload.lower() == "none": + self._attr_current_option = None + return + + if payload not in self.options: + _LOGGER.error( + "Invalid option for %s: '%s' (valid options: %s)", + self.entity_id, + payload, + self.options, + ) + return + self._attr_current_option = payload + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_current_option"}) - def message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT messages.""" - payload = str(self._value_template(msg.payload)) - if not payload.strip(): # No output from template, ignore - _LOGGER.debug( - "Ignoring empty payload '%s' after rendering for topic %s", - payload, - msg.topic, - ) - return - if payload.lower() == "none": - self._attr_current_option = None - return - - if payload not in self.options: - _LOGGER.error( - "Invalid option for %s: '%s' (valid options: %s)", - self.entity_id, - payload, - self.options, - ) - return - self._attr_current_option = payload - if self._config.get(CONF_STATE_TOPIC) is None: # Force into optimistic mode. self._attr_assumed_state = True - else: - self._sub_state = subscription.async_prepare_subscribe_topics( - self.hass, - self._sub_state, - { - "state_topic": { - "topic": self._config.get(CONF_STATE_TOPIC), - "msg_callback": message_received, - "qos": self._config[CONF_QOS], - "encoding": self._config[CONF_ENCODING] or None, - } - }, - ) + return + self._sub_state = subscription.async_prepare_subscribe_topics( + self.hass, + self._sub_state, + { + "state_topic": { + "topic": self._config.get(CONF_STATE_TOPIC), + "msg_callback": partial( + self._message_callback, + self._message_received, + {"_attr_current_option"}, + ), + "entity_id": self.entity_id, + "qos": self._config[CONF_QOS], + "encoding": self._config[CONF_ENCODING] or None, + } + }, + ) async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) if self._attr_assumed_state and ( last_state := await self.async_get_last_state() diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index d37da597ffb..c8fe932ed71 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -305,7 +305,7 @@ class MqttSensor(MqttEntity, RestoreSensor): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @callback def _value_is_expired(self, *_: datetime) -> None: diff --git a/homeassistant/components/mqtt/siren.py b/homeassistant/components/mqtt/siren.py index 9188e3d03ae..06cb2677c09 100644 --- a/homeassistant/components/mqtt/siren.py +++ b/homeassistant/components/mqtt/siren.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging from typing import Any, cast @@ -48,12 +49,7 @@ from .const import ( PAYLOAD_EMPTY_JSON, PAYLOAD_NONE, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( MqttCommandTemplate, MqttValueTemplate, @@ -205,92 +201,94 @@ class MqttSiren(MqttEntity, SirenEntity): entity=self, ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self) -> None: - """(Re)Subscribe to topics.""" - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_is_on", "_extra_attributes"}) - def state_message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT state messages.""" - payload = self._value_template(msg.payload) - if not payload or payload == PAYLOAD_EMPTY_JSON: + @callback + def _state_message_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT state messages.""" + payload = self._value_template(msg.payload) + if not payload or payload == PAYLOAD_EMPTY_JSON: + _LOGGER.debug( + "Ignoring empty payload '%s' after rendering for topic %s", + payload, + msg.topic, + ) + return + json_payload: dict[str, Any] = {} + if payload in [self._state_on, self._state_off, PAYLOAD_NONE]: + json_payload = {STATE: payload} + else: + try: + json_payload = json_loads_object(payload) _LOGGER.debug( - "Ignoring empty payload '%s' after rendering for topic %s", - payload, + ( + "JSON payload detected after processing payload '%s' on" + " topic %s" + ), + json_payload, + msg.topic, + ) + except JSON_DECODE_EXCEPTIONS: + _LOGGER.warning( + ( + "No valid (JSON) payload detected after processing payload" + " '%s' on topic %s" + ), + json_payload, msg.topic, ) return - json_payload: dict[str, Any] = {} - if payload in [self._state_on, self._state_off, PAYLOAD_NONE]: - json_payload = {STATE: payload} - else: - try: - json_payload = json_loads_object(payload) - _LOGGER.debug( - ( - "JSON payload detected after processing payload '%s' on" - " topic %s" - ), - json_payload, - msg.topic, - ) - except JSON_DECODE_EXCEPTIONS: - _LOGGER.warning( - ( - "No valid (JSON) payload detected after processing payload" - " '%s' on topic %s" - ), - json_payload, - msg.topic, - ) - return - if STATE in json_payload: - if json_payload[STATE] == self._state_on: - self._attr_is_on = True - if json_payload[STATE] == self._state_off: - self._attr_is_on = False - if json_payload[STATE] == PAYLOAD_NONE: - self._attr_is_on = None - del json_payload[STATE] + if STATE in json_payload: + if json_payload[STATE] == self._state_on: + self._attr_is_on = True + if json_payload[STATE] == self._state_off: + self._attr_is_on = False + if json_payload[STATE] == PAYLOAD_NONE: + self._attr_is_on = None + del json_payload[STATE] - if json_payload: - # process attributes - try: - params: SirenTurnOnServiceParameters - params = vol.All(TURN_ON_SCHEMA)(json_payload) - except vol.MultipleInvalid as invalid_siren_parameters: - _LOGGER.warning( - "Unable to update siren state attributes from payload '%s': %s", - json_payload, - invalid_siren_parameters, - ) - return - # To be able to track changes to self._extra_attributes we assign - # a fresh copy to make the original tracked reference immutable. - self._extra_attributes = dict(self._extra_attributes) - self._update(process_turn_on_params(self, params)) + if json_payload: + # process attributes + try: + params: SirenTurnOnServiceParameters + params = vol.All(TURN_ON_SCHEMA)(json_payload) + except vol.MultipleInvalid as invalid_siren_parameters: + _LOGGER.warning( + "Unable to update siren state attributes from payload '%s': %s", + json_payload, + invalid_siren_parameters, + ) + return + # To be able to track changes to self._extra_attributes we assign + # a fresh copy to make the original tracked reference immutable. + self._extra_attributes = dict(self._extra_attributes) + self._update(process_turn_on_params(self, params)) + def _prepare_subscribe_topics(self) -> None: + """(Re)Subscribe to topics.""" if self._config.get(CONF_STATE_TOPIC) is None: # Force into optimistic mode. self._optimistic = True - else: - self._sub_state = subscription.async_prepare_subscribe_topics( - self.hass, - self._sub_state, - { - CONF_STATE_TOPIC: { - "topic": self._config.get(CONF_STATE_TOPIC), - "msg_callback": state_message_received, - "qos": self._config[CONF_QOS], - "encoding": self._config[CONF_ENCODING] or None, - } - }, - ) + return + self._sub_state = subscription.async_prepare_subscribe_topics( + self.hass, + self._sub_state, + { + CONF_STATE_TOPIC: { + "topic": self._config.get(CONF_STATE_TOPIC), + "msg_callback": partial( + self._message_callback, + self._state_message_received, + {"_attr_is_on", "_extra_attributes"}, + ), + "entity_id": self.entity_id, + "qos": self._config[CONF_QOS], + "encoding": self._config[CONF_ENCODING] or None, + } + }, + ) async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) @property def extra_state_attributes(self) -> dict[str, Any] | None: diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index d0dc98484b3..9e3ea21222f 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -2,14 +2,15 @@ from __future__ import annotations -from collections.abc import Callable, Coroutine +from collections.abc import Callable from dataclasses import dataclass +from functools import partial from typing import TYPE_CHECKING, Any -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback -from .. import mqtt from . import debug_info +from .client import async_subscribe_internal from .const import DEFAULT_QOS from .models import MessageCallbackType @@ -21,7 +22,7 @@ class EntitySubscription: hass: HomeAssistant topic: str | None message_callback: MessageCallbackType - subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None + should_subscribe: bool | None unsubscribe_callback: Callable[[], None] | None qos: int = 0 encoding: str = "utf-8" @@ -53,15 +54,16 @@ class EntitySubscription: self.hass, self.message_callback, self.topic, self.entity_id ) - self.subscribe_task = mqtt.async_subscribe( - hass, self.topic, self.message_callback, self.qos, self.encoding - ) + self.should_subscribe = True - async def subscribe(self) -> None: + @callback + def subscribe(self) -> None: """Subscribe to a topic.""" - if not self.subscribe_task: + if not self.should_subscribe or not self.topic: return - self.unsubscribe_callback = await self.subscribe_task + self.unsubscribe_callback = async_subscribe_internal( + self.hass, self.topic, self.message_callback, self.qos, self.encoding + ) def _should_resubscribe(self, other: EntitySubscription | None) -> bool: """Check if we should re-subscribe to the topic using the old state.""" @@ -79,6 +81,7 @@ class EntitySubscription: ) +@callback def async_prepare_subscribe_topics( hass: HomeAssistant, new_state: dict[str, EntitySubscription] | None, @@ -107,7 +110,7 @@ def async_prepare_subscribe_topics( qos=value.get("qos", DEFAULT_QOS), encoding=value.get("encoding", "utf-8"), hass=hass, - subscribe_task=None, + should_subscribe=None, entity_id=value.get("entity_id", None), ) # Get the current subscription state @@ -135,12 +138,29 @@ async def async_subscribe_topics( sub_state: dict[str, EntitySubscription], ) -> None: """(Re)Subscribe to a set of MQTT topics.""" + async_subscribe_topics_internal(hass, sub_state) + + +@callback +def async_subscribe_topics_internal( + hass: HomeAssistant, + sub_state: dict[str, EntitySubscription], +) -> None: + """(Re)Subscribe to a set of MQTT topics. + + This function is internal to the MQTT integration and should not be called + from outside the integration. + """ for sub in sub_state.values(): - await sub.subscribe() + sub.subscribe() -def async_unsubscribe_topics( - hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None -) -> dict[str, EntitySubscription]: - """Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" - return async_prepare_subscribe_topics(hass, sub_state, {}) +if TYPE_CHECKING: + + def async_unsubscribe_topics( + hass: HomeAssistant, sub_state: dict[str, EntitySubscription] | None + ) -> dict[str, EntitySubscription]: + """Unsubscribe from all MQTT topics managed by async_subscribe_topics.""" + + +async_unsubscribe_topics = partial(async_prepare_subscribe_topics, topics={}) diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index 5cbfefe0111..9f266a0e9ab 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial from typing import Any import voluptuous as vol @@ -36,12 +37,7 @@ from .const import ( CONF_STATE_TOPIC, PAYLOAD_NONE, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import MqttValueTemplate, ReceiveMessage from .schemas import MQTT_ENTITY_COMMON_SCHEMA @@ -118,42 +114,44 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): self._config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value + @callback + def _state_message_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT state messages.""" + payload = self._value_template(msg.payload) + if payload == self._state_on: + self._attr_is_on = True + elif payload == self._state_off: + self._attr_is_on = False + elif payload == PAYLOAD_NONE: + self._attr_is_on = None + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_is_on"}) - def state_message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT state messages.""" - payload = self._value_template(msg.payload) - if payload == self._state_on: - self._attr_is_on = True - elif payload == self._state_off: - self._attr_is_on = False - elif payload == PAYLOAD_NONE: - self._attr_is_on = None - if self._config.get(CONF_STATE_TOPIC) is None: # Force into optimistic mode. self._optimistic = True - else: - self._sub_state = subscription.async_prepare_subscribe_topics( - self.hass, - self._sub_state, - { - CONF_STATE_TOPIC: { - "topic": self._config.get(CONF_STATE_TOPIC), - "msg_callback": state_message_received, - "qos": self._config[CONF_QOS], - "encoding": self._config[CONF_ENCODING] or None, - } - }, - ) + return + self._sub_state = subscription.async_prepare_subscribe_topics( + self.hass, + self._sub_state, + { + CONF_STATE_TOPIC: { + "topic": self._config.get(CONF_STATE_TOPIC), + "msg_callback": partial( + self._message_callback, + self._state_message_received, + {"_attr_is_on"}, + ), + "entity_id": self.entity_id, + "qos": self._config[CONF_QOS], + "encoding": self._config[CONF_ENCODING] or None, + } + }, + ) async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) if self._optimistic and (last_state := await self.async_get_last_state()): self._attr_is_on = last_state.state == STATE_ON diff --git a/homeassistant/components/mqtt/tag.py b/homeassistant/components/mqtt/tag.py index 4ecf0862827..55f7e775ae9 100644 --- a/homeassistant/components/mqtt/tag.py +++ b/homeassistant/components/mqtt/tag.py @@ -167,7 +167,7 @@ class MQTTTagScanner(MqttDiscoveryDeviceUpdateMixin): } }, ) - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_tear_down(self) -> None: """Cleanup tag scanner.""" diff --git a/homeassistant/components/mqtt/text.py b/homeassistant/components/mqtt/text.py index c9b0a6c9d70..abced8b8744 100644 --- a/homeassistant/components/mqtt/text.py +++ b/homeassistant/components/mqtt/text.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging import re from typing import Any @@ -34,12 +35,7 @@ from .const import ( CONF_RETAIN, CONF_STATE_TOPIC, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ( MessageCallbackType, MqttCommandTemplate, @@ -160,32 +156,41 @@ class MqttTextEntity(MqttEntity, TextEntity): self._optimistic = optimistic or config.get(CONF_STATE_TOPIC) is None self._attr_assumed_state = bool(self._optimistic) + @callback + def _handle_state_message_received(self, msg: ReceiveMessage) -> None: + """Handle receiving state message via MQTT.""" + payload = str(self._value_template(msg.payload)) + if check_state_too_long(_LOGGER, payload, self.entity_id, msg): + return + self._attr_native_value = payload + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" topics: dict[str, Any] = {} def add_subscription( - topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType + topics: dict[str, Any], + topic: str, + msg_callback: MessageCallbackType, + tracked_attributes: set[str], ) -> None: if self._config.get(topic) is not None: topics[topic] = { "topic": self._config[topic], - "msg_callback": msg_callback, + "msg_callback": partial( + self._message_callback, msg_callback, tracked_attributes + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_native_value"}) - def handle_state_message_received(msg: ReceiveMessage) -> None: - """Handle receiving state message via MQTT.""" - payload = str(self._value_template(msg.payload)) - if check_state_too_long(_LOGGER, payload, self.entity_id, msg): - return - self._attr_native_value = payload - - add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received) + add_subscription( + topics, + CONF_STATE_TOPIC, + self._handle_state_message_received, + {"_attr_native_value"}, + ) self._sub_state = subscription.async_prepare_subscribe_topics( self.hass, self._sub_state, topics @@ -193,7 +198,7 @@ class MqttTextEntity(MqttEntity, TextEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_set_value(self, value: str) -> None: """Change the text.""" diff --git a/homeassistant/components/mqtt/update.py b/homeassistant/components/mqtt/update.py index 25cc60155a0..ee29601e585 100644 --- a/homeassistant/components/mqtt/update.py +++ b/homeassistant/components/mqtt/update.py @@ -2,6 +2,7 @@ from __future__ import annotations +from functools import partial import logging from typing import Any, TypedDict, cast @@ -32,12 +33,7 @@ from .const import ( CONF_STATE_TOPIC, PAYLOAD_EMPTY_JSON, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import MessageCallbackType, MqttValueTemplate, ReceiveMessage from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .util import valid_publish_topic, valid_subscribe_topic @@ -141,25 +137,104 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): ).async_render_with_possible_json_value, } + @callback + def _handle_state_message_received(self, msg: ReceiveMessage) -> None: + """Handle receiving state message via MQTT.""" + payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload) + + if not payload or payload == PAYLOAD_EMPTY_JSON: + _LOGGER.debug( + "Ignoring empty payload '%s' after rendering for topic %s", + payload, + msg.topic, + ) + return + + json_payload: _MqttUpdatePayloadType = {} + try: + rendered_json_payload = json_loads(payload) + if isinstance(rendered_json_payload, dict): + _LOGGER.debug( + ( + "JSON payload detected after processing payload '%s' on" + " topic %s" + ), + rendered_json_payload, + msg.topic, + ) + json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload) + else: + _LOGGER.debug( + ( + "Non-dictionary JSON payload detected after processing" + " payload '%s' on topic %s" + ), + payload, + msg.topic, + ) + json_payload = {"installed_version": str(payload)} + except JSON_DECODE_EXCEPTIONS: + _LOGGER.debug( + ( + "No valid (JSON) payload detected after processing payload '%s'" + " on topic %s" + ), + payload, + msg.topic, + ) + json_payload["installed_version"] = str(payload) + + if "installed_version" in json_payload: + self._attr_installed_version = json_payload["installed_version"] + + if "latest_version" in json_payload: + self._attr_latest_version = json_payload["latest_version"] + + if "title" in json_payload: + self._attr_title = json_payload["title"] + + if "release_summary" in json_payload: + self._attr_release_summary = json_payload["release_summary"] + + if "release_url" in json_payload: + self._attr_release_url = json_payload["release_url"] + + if "entity_picture" in json_payload: + self._entity_picture = json_payload["entity_picture"] + + @callback + def _handle_latest_version_received(self, msg: ReceiveMessage) -> None: + """Handle receiving latest version via MQTT.""" + latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload) + + if isinstance(latest_version, str) and latest_version != "": + self._attr_latest_version = latest_version + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" topics: dict[str, Any] = {} def add_subscription( - topics: dict[str, Any], topic: str, msg_callback: MessageCallbackType + topics: dict[str, Any], + topic: str, + msg_callback: MessageCallbackType, + tracked_attributes: set[str], ) -> None: if self._config.get(topic) is not None: topics[topic] = { "topic": self._config[topic], - "msg_callback": msg_callback, + "msg_callback": partial( + self._message_callback, msg_callback, tracked_attributes + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, + add_subscription( + topics, + CONF_STATE_TOPIC, + self._handle_state_message_received, { "_attr_installed_version", "_attr_latest_version", @@ -169,84 +244,11 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): "_entity_picture", }, ) - def handle_state_message_received(msg: ReceiveMessage) -> None: - """Handle receiving state message via MQTT.""" - payload = self._templates[CONF_VALUE_TEMPLATE](msg.payload) - - if not payload or payload == PAYLOAD_EMPTY_JSON: - _LOGGER.debug( - "Ignoring empty payload '%s' after rendering for topic %s", - payload, - msg.topic, - ) - return - - json_payload: _MqttUpdatePayloadType = {} - try: - rendered_json_payload = json_loads(payload) - if isinstance(rendered_json_payload, dict): - _LOGGER.debug( - ( - "JSON payload detected after processing payload '%s' on" - " topic %s" - ), - rendered_json_payload, - msg.topic, - ) - json_payload = cast(_MqttUpdatePayloadType, rendered_json_payload) - else: - _LOGGER.debug( - ( - "Non-dictionary JSON payload detected after processing" - " payload '%s' on topic %s" - ), - payload, - msg.topic, - ) - json_payload = {"installed_version": str(payload)} - except JSON_DECODE_EXCEPTIONS: - _LOGGER.debug( - ( - "No valid (JSON) payload detected after processing payload '%s'" - " on topic %s" - ), - payload, - msg.topic, - ) - json_payload["installed_version"] = str(payload) - - if "installed_version" in json_payload: - self._attr_installed_version = json_payload["installed_version"] - - if "latest_version" in json_payload: - self._attr_latest_version = json_payload["latest_version"] - - if "title" in json_payload: - self._attr_title = json_payload["title"] - - if "release_summary" in json_payload: - self._attr_release_summary = json_payload["release_summary"] - - if "release_url" in json_payload: - self._attr_release_url = json_payload["release_url"] - - if "entity_picture" in json_payload: - self._entity_picture = json_payload["entity_picture"] - - add_subscription(topics, CONF_STATE_TOPIC, handle_state_message_received) - - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change(self, {"_attr_latest_version"}) - def handle_latest_version_received(msg: ReceiveMessage) -> None: - """Handle receiving latest version via MQTT.""" - latest_version = self._templates[CONF_LATEST_VERSION_TEMPLATE](msg.payload) - - if isinstance(latest_version, str) and latest_version != "": - self._attr_latest_version = latest_version - add_subscription( - topics, CONF_LATEST_VERSION_TOPIC, handle_latest_version_received + topics, + CONF_LATEST_VERSION_TOPIC, + self._handle_latest_version_received, + {"_attr_latest_version"}, ) self._sub_state = subscription.async_prepare_subscribe_topics( @@ -255,7 +257,7 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_install( self, version: str | None, backup: bool, **kwargs: Any diff --git a/homeassistant/components/mqtt/vacuum.py b/homeassistant/components/mqtt/vacuum.py index 57265008025..5c8c2fd2ba5 100644 --- a/homeassistant/components/mqtt/vacuum.py +++ b/homeassistant/components/mqtt/vacuum.py @@ -8,6 +8,7 @@ from __future__ import annotations from collections.abc import Callable +from functools import partial import logging from typing import Any, cast @@ -49,12 +50,7 @@ from .const import ( CONF_STATE_TOPIC, DOMAIN, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import ReceiveMessage from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .util import valid_publish_topic @@ -322,31 +318,32 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity): self._attr_fan_speed = self._state_attrs.get(FAN_SPEED, 0) self._attr_battery_level = max(0, min(100, self._state_attrs.get(BATTERY, 0))) + @callback + def _state_message_received(self, msg: ReceiveMessage) -> None: + """Handle state MQTT message.""" + payload = json_loads_object(msg.payload) + if STATE in payload and ( + (state := payload[STATE]) in POSSIBLE_STATES or state is None + ): + self._attr_state = ( + POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None + ) + del payload[STATE] + self._update_state_attributes(payload) + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" topics: dict[str, Any] = {} - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, {"_attr_battery_level", "_attr_fan_speed", "_attr_state"} - ) - def state_message_received(msg: ReceiveMessage) -> None: - """Handle state MQTT message.""" - payload = json_loads_object(msg.payload) - if STATE in payload and ( - (state := payload[STATE]) in POSSIBLE_STATES or state is None - ): - self._attr_state = ( - POSSIBLE_STATES[cast(str, state)] if payload[STATE] else None - ) - del payload[STATE] - self._update_state_attributes(payload) - if state_topic := self._config.get(CONF_STATE_TOPIC): topics["state_position_topic"] = { "topic": state_topic, - "msg_callback": state_message_received, + "msg_callback": partial( + self._message_callback, + self._state_message_received, + {"_attr_battery_level", "_attr_fan_speed", "_attr_state"}, + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } @@ -356,7 +353,7 @@ class MqttStateVacuum(MqttEntity, StateVacuumEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def _async_publish_command(self, feature: VacuumEntityFeature) -> None: """Publish a command.""" diff --git a/homeassistant/components/mqtt/valve.py b/homeassistant/components/mqtt/valve.py index 89a60eef852..ce89c6c2daf 100644 --- a/homeassistant/components/mqtt/valve.py +++ b/homeassistant/components/mqtt/valve.py @@ -3,6 +3,7 @@ from __future__ import annotations from contextlib import suppress +from functools import partial import logging from typing import Any @@ -61,12 +62,7 @@ from .const import ( DEFAULT_RETAIN, PAYLOAD_NONE, ) -from .debug_info import log_messages -from .mixins import ( - MqttEntity, - async_setup_entity_entry_helper, - write_state_on_attr_change, -) +from .mixins import MqttEntity, async_setup_entity_entry_helper from .models import MqttCommandTemplate, MqttValueTemplate, ReceiveMessage from .schemas import MQTT_ENTITY_COMMON_SCHEMA from .util import valid_publish_topic, valid_subscribe_topic @@ -302,65 +298,63 @@ class MqttValve(MqttEntity, ValveEntity): return self._update_state(state) + @callback + def _state_message_received(self, msg: ReceiveMessage) -> None: + """Handle new MQTT state messages.""" + payload = self._value_template(msg.payload) + payload_dict: Any = None + position_payload: Any = payload + state_payload: Any = payload + + if not payload: + _LOGGER.debug("Ignoring empty state message from '%s'", msg.topic) + return + + with suppress(*JSON_DECODE_EXCEPTIONS): + payload_dict = json_loads(payload) + if isinstance(payload_dict, dict): + if self.reports_position and "position" not in payload_dict: + _LOGGER.warning( + "Missing required `position` attribute in json payload " + "on topic '%s', got: %s", + msg.topic, + payload, + ) + return + if not self.reports_position and "state" not in payload_dict: + _LOGGER.warning( + "Missing required `state` attribute in json payload " + " on topic '%s', got: %s", + msg.topic, + payload, + ) + return + position_payload = payload_dict.get("position") + state_payload = payload_dict.get("state") + + if self._config[CONF_REPORTS_POSITION]: + self._process_position_valve_update(msg, position_payload, state_payload) + else: + self._process_binary_valve_update(msg, state_payload) + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" topics = {} - @callback - @log_messages(self.hass, self.entity_id) - @write_state_on_attr_change( - self, - { - "_attr_current_valve_position", - "_attr_is_closed", - "_attr_is_closing", - "_attr_is_opening", - }, - ) - def state_message_received(msg: ReceiveMessage) -> None: - """Handle new MQTT state messages.""" - payload = self._value_template(msg.payload) - payload_dict: Any = None - position_payload: Any = payload - state_payload: Any = payload - - if not payload: - _LOGGER.debug("Ignoring empty state message from '%s'", msg.topic) - return - - with suppress(*JSON_DECODE_EXCEPTIONS): - payload_dict = json_loads(payload) - if isinstance(payload_dict, dict): - if self.reports_position and "position" not in payload_dict: - _LOGGER.warning( - "Missing required `position` attribute in json payload " - "on topic '%s', got: %s", - msg.topic, - payload, - ) - return - if not self.reports_position and "state" not in payload_dict: - _LOGGER.warning( - "Missing required `state` attribute in json payload " - " on topic '%s', got: %s", - msg.topic, - payload, - ) - return - position_payload = payload_dict.get("position") - state_payload = payload_dict.get("state") - - if self._config[CONF_REPORTS_POSITION]: - self._process_position_valve_update( - msg, position_payload, state_payload - ) - else: - self._process_binary_valve_update(msg, state_payload) - if self._config.get(CONF_STATE_TOPIC): topics["state_topic"] = { "topic": self._config.get(CONF_STATE_TOPIC), - "msg_callback": state_message_received, + "msg_callback": partial( + self._message_callback, + self._state_message_received, + { + "_attr_current_valve_position", + "_attr_is_closed", + "_attr_is_closing", + "_attr_is_opening", + }, + ), + "entity_id": self.entity_id, "qos": self._config[CONF_QOS], "encoding": self._config[CONF_ENCODING] or None, } @@ -371,7 +365,7 @@ class MqttValve(MqttEntity, ValveEntity): async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" - await subscription.async_subscribe_topics(self.hass, self._sub_state) + subscription.async_subscribe_topics_internal(self.hass, self._sub_state) async def async_open_valve(self) -> None: """Move the valve up. diff --git a/homeassistant/components/ollama/conversation.py b/homeassistant/components/ollama/conversation.py index cbec719780a..fa7a3c3797e 100644 --- a/homeassistant/components/ollama/conversation.py +++ b/homeassistant/components/ollama/conversation.py @@ -9,6 +9,7 @@ from typing import Literal import ollama from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.conversation import trace from homeassistant.components.homeassistant.exposed_entities import async_should_expose from homeassistant.config_entries import ConfigEntry from homeassistant.const import MATCH_ALL @@ -138,6 +139,11 @@ class OllamaConversationEntity( ollama.Message(role=MessageRole.USER.value, content=user_input.text) ) + trace.async_conversation_trace_append( + trace.ConversationTraceEventType.AGENT_DETAIL, + {"messages": message_history.messages}, + ) + # Get response try: response = await client.chat( diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py index af1ec3d2fc6..09b909b3d5e 100644 --- a/homeassistant/components/openai_conversation/config_flow.py +++ b/homeassistant/components/openai_conversation/config_flow.py @@ -31,14 +31,15 @@ from .const import ( CONF_CHAT_MODEL, CONF_MAX_TOKENS, CONF_PROMPT, + CONF_RECOMMENDED, CONF_TEMPERATURE, CONF_TOP_P, - DEFAULT_CHAT_MODEL, - DEFAULT_MAX_TOKENS, DEFAULT_PROMPT, - DEFAULT_TEMPERATURE, - DEFAULT_TOP_P, DOMAIN, + RECOMMENDED_CHAT_MODEL, + RECOMMENDED_MAX_TOKENS, + RECOMMENDED_TEMPERATURE, + RECOMMENDED_TOP_P, ) _LOGGER = logging.getLogger(__name__) @@ -49,6 +50,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema( } ) +RECOMMENDED_OPTIONS = { + CONF_RECOMMENDED: True, + CONF_LLM_HASS_API: llm.LLM_API_ASSIST, + CONF_PROMPT: DEFAULT_PROMPT, +} + async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: """Validate the user input allows us to connect. @@ -88,7 +95,7 @@ class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN): return self.async_create_entry( title="ChatGPT", data=user_input, - options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, + options=RECOMMENDED_OPTIONS, ) return self.async_show_form( @@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow): def __init__(self, config_entry: ConfigEntry) -> None: """Initialize options flow.""" self.config_entry = config_entry + self.last_rendered_recommended = config_entry.options.get( + CONF_RECOMMENDED, False + ) async def async_step_init( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Manage the options.""" + options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options + if user_input is not None: - if user_input[CONF_LLM_HASS_API] == "none": - user_input.pop(CONF_LLM_HASS_API) - return self.async_create_entry(title="", data=user_input) - schema = openai_config_option_schema(self.hass, self.config_entry.options) + if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended: + if user_input[CONF_LLM_HASS_API] == "none": + user_input.pop(CONF_LLM_HASS_API) + return self.async_create_entry(title="", data=user_input) + + # Re-render the options again, now with the recommended options shown/hidden + self.last_rendered_recommended = user_input[CONF_RECOMMENDED] + + options = { + CONF_RECOMMENDED: user_input[CONF_RECOMMENDED], + CONF_PROMPT: user_input[CONF_PROMPT], + CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API], + } + + schema = openai_config_option_schema(self.hass, options) return self.async_show_form( step_id="init", data_schema=vol.Schema(schema), @@ -127,16 +150,16 @@ class OpenAIOptionsFlow(OptionsFlow): def openai_config_option_schema( hass: HomeAssistant, - options: MappingProxyType[str, Any], + options: dict[str, Any] | MappingProxyType[str, Any], ) -> dict: """Return a schema for OpenAI completion options.""" - apis: list[SelectOptionDict] = [ + hass_apis: list[SelectOptionDict] = [ SelectOptionDict( label="No control", value="none", ) ] - apis.extend( + hass_apis.extend( SelectOptionDict( label=api.name, value=api.id, @@ -144,38 +167,46 @@ def openai_config_option_schema( for api in llm.async_get_apis(hass) ) - return { + schema = { vol.Optional( CONF_PROMPT, - description={"suggested_value": options.get(CONF_PROMPT)}, - default=DEFAULT_PROMPT, + description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)}, ): TemplateSelector(), vol.Optional( CONF_LLM_HASS_API, description={"suggested_value": options.get(CONF_LLM_HASS_API)}, default="none", - ): SelectSelector(SelectSelectorConfig(options=apis)), - vol.Optional( - CONF_CHAT_MODEL, - description={ - # New key in HA 2023.4 - "suggested_value": options.get(CONF_CHAT_MODEL) - }, - default=DEFAULT_CHAT_MODEL, - ): str, - vol.Optional( - CONF_MAX_TOKENS, - description={"suggested_value": options.get(CONF_MAX_TOKENS)}, - default=DEFAULT_MAX_TOKENS, - ): int, - vol.Optional( - CONF_TOP_P, - description={"suggested_value": options.get(CONF_TOP_P)}, - default=DEFAULT_TOP_P, - ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), - vol.Optional( - CONF_TEMPERATURE, - description={"suggested_value": options.get(CONF_TEMPERATURE)}, - default=DEFAULT_TEMPERATURE, - ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)), + ): SelectSelector(SelectSelectorConfig(options=hass_apis)), + vol.Required( + CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False) + ): bool, } + + if options.get(CONF_RECOMMENDED): + return schema + + schema.update( + { + vol.Optional( + CONF_CHAT_MODEL, + description={"suggested_value": options.get(CONF_CHAT_MODEL)}, + default=RECOMMENDED_CHAT_MODEL, + ): str, + vol.Optional( + CONF_MAX_TOKENS, + description={"suggested_value": options.get(CONF_MAX_TOKENS)}, + default=RECOMMENDED_MAX_TOKENS, + ): int, + vol.Optional( + CONF_TOP_P, + description={"suggested_value": options.get(CONF_TOP_P)}, + default=RECOMMENDED_TOP_P, + ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), + vol.Optional( + CONF_TEMPERATURE, + description={"suggested_value": options.get(CONF_TEMPERATURE)}, + default=RECOMMENDED_TEMPERATURE, + ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)), + } + ) + return schema diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index 27ef86bf918..995d80e02f1 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -4,13 +4,15 @@ import logging DOMAIN = "openai_conversation" LOGGER = logging.getLogger(__package__) + +CONF_RECOMMENDED = "recommended" CONF_PROMPT = "prompt" DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point.""" CONF_CHAT_MODEL = "chat_model" -DEFAULT_CHAT_MODEL = "gpt-4o" +RECOMMENDED_CHAT_MODEL = "gpt-4o" CONF_MAX_TOKENS = "max_tokens" -DEFAULT_MAX_TOKENS = 150 +RECOMMENDED_MAX_TOKENS = 150 CONF_TOP_P = "top_p" -DEFAULT_TOP_P = 1.0 +RECOMMENDED_TOP_P = 1.0 CONF_TEMPERATURE = "temperature" -DEFAULT_TEMPERATURE = 1.0 +RECOMMENDED_TEMPERATURE = 1.0 diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index a878b934317..be3b8ea9126 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -8,6 +8,7 @@ import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation +from homeassistant.components.conversation import trace from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant @@ -22,13 +23,13 @@ from .const import ( CONF_PROMPT, CONF_TEMPERATURE, CONF_TOP_P, - DEFAULT_CHAT_MODEL, - DEFAULT_MAX_TOKENS, DEFAULT_PROMPT, - DEFAULT_TEMPERATURE, - DEFAULT_TOP_P, DOMAIN, LOGGER, + RECOMMENDED_CHAT_MODEL, + RECOMMENDED_MAX_TOKENS, + RECOMMENDED_TEMPERATURE, + RECOMMENDED_TOP_P, ) # Max number of back and forth with the LLM to generate a response @@ -97,15 +98,14 @@ class OpenAIConversationEntity( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" + options = self.entry.options intent_response = intent.IntentResponse(language=user_input.language) llm_api: llm.API | None = None tools: list[dict[str, Any]] | None = None - if self.entry.options.get(CONF_LLM_HASS_API): + if options.get(CONF_LLM_HASS_API): try: - llm_api = llm.async_get_api( - self.hass, self.entry.options[CONF_LLM_HASS_API] - ) + llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API]) except HomeAssistantError as err: LOGGER.error("Error getting LLM API: %s", err) intent_response.async_set_error( @@ -117,26 +117,12 @@ class OpenAIConversationEntity( ) tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] - model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) - max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) - top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P) - temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE) - if user_input.conversation_id in self.history: conversation_id = user_input.conversation_id messages = self.history[conversation_id] else: conversation_id = ulid.ulid_now() try: - prompt = template.Template( - self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass - ).async_render( - { - "ha_name": self.hass.config.location_name, - }, - parse_result=False, - ) - if llm_api: empty_tool_input = llm.ToolInput( tool_name="", @@ -149,11 +135,24 @@ class OpenAIConversationEntity( device_id=user_input.device_id, ) - prompt = ( - await llm_api.async_get_api_prompt(empty_tool_input) - + "\n" - + prompt + api_prompt = await llm_api.async_get_api_prompt(empty_tool_input) + + else: + api_prompt = llm.PROMPT_NO_API_CONFIGURED + + prompt = "\n".join( + ( + template.Template( + options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass + ).async_render( + { + "ha_name": self.hass.config.location_name, + }, + parse_result=False, + ), + api_prompt, ) + ) except TemplateError as err: LOGGER.error("Error rendering prompt: %s", err) @@ -170,7 +169,10 @@ class OpenAIConversationEntity( messages.append({"role": "user", "content": user_input.text}) - LOGGER.debug("Prompt for %s: %s", model, messages) + LOGGER.debug("Prompt: %s", messages) + trace.async_conversation_trace_append( + trace.ConversationTraceEventType.AGENT_DETAIL, {"messages": messages} + ) client = self.hass.data[DOMAIN][self.entry.entry_id] @@ -178,12 +180,12 @@ class OpenAIConversationEntity( for _iteration in range(MAX_TOOL_ITERATIONS): try: result = await client.chat.completions.create( - model=model, + model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL), messages=messages, tools=tools, - max_tokens=max_tokens, - top_p=top_p, - temperature=temperature, + max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), + top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P), + temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE), user=conversation_id, ) except openai.OpenAIError as err: diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index 01060afc7f1..1e93c60b6a9 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -22,7 +22,8 @@ "max_tokens": "Maximum tokens to return in response", "temperature": "Temperature", "top_p": "Top P", - "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]", + "recommended": "Recommended model settings" }, "data_description": { "prompt": "Instruct how the LLM should respond. This can be a template." diff --git a/homeassistant/components/spotify/__init__.py b/homeassistant/components/spotify/__init__.py index 8d5183a459d..632871ba36e 100644 --- a/homeassistant/components/spotify/__init__.py +++ b/homeassistant/components/spotify/__init__.py @@ -30,7 +30,6 @@ from .util import ( PLATFORMS = [Platform.MEDIA_PLAYER] - __all__ = [ "async_browse_media", "DOMAIN", @@ -50,7 +49,10 @@ class HomeAssistantSpotifyData: session: OAuth2Session -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +type SpotifyConfigEntry = ConfigEntry[HomeAssistantSpotifyData] + + +async def async_setup_entry(hass: HomeAssistant, entry: SpotifyConfigEntry) -> bool: """Set up Spotify from a config entry.""" implementation = await async_get_config_entry_implementation(hass, entry) session = OAuth2Session(hass, entry, implementation) @@ -100,8 +102,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ) await device_coordinator.async_config_entry_first_refresh() - hass.data.setdefault(DOMAIN, {}) - hass.data[DOMAIN][entry.entry_id] = HomeAssistantSpotifyData( + entry.runtime_data = HomeAssistantSpotifyData( client=spotify, current_user=current_user, devices=device_coordinator, @@ -117,6 +118,4 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload Spotify config entry.""" - if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): - del hass.data[DOMAIN][entry.entry_id] - return unload_ok + return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) diff --git a/homeassistant/components/spotify/browse_media.py b/homeassistant/components/spotify/browse_media.py index cc8f57be1bb..a1d3d9c804a 100644 --- a/homeassistant/components/spotify/browse_media.py +++ b/homeassistant/components/spotify/browse_media.py @@ -5,7 +5,7 @@ from __future__ import annotations from enum import StrEnum from functools import partial import logging -from typing import Any +from typing import TYPE_CHECKING, Any from spotipy import Spotify import yarl @@ -22,6 +22,9 @@ from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session from .const import DOMAIN, MEDIA_PLAYER_PREFIX, MEDIA_TYPE_SHOW, PLAYABLE_MEDIA_TYPES from .util import fetch_image_url +if TYPE_CHECKING: + from . import HomeAssistantSpotifyData + BROWSE_LIMIT = 48 @@ -140,21 +143,21 @@ async def async_browse_media( # Check if caller is requesting the root nodes if media_content_type is None and media_content_id is None: - children = [] - for config_entry_id in hass.data[DOMAIN]: - config_entry = hass.config_entries.async_get_entry(config_entry_id) - assert config_entry is not None - children.append( - BrowseMedia( - title=config_entry.title, - media_class=MediaClass.APP, - media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry_id}", - media_content_type=f"{MEDIA_PLAYER_PREFIX}library", - thumbnail="https://brands.home-assistant.io/_/spotify/logo.png", - can_play=False, - can_expand=True, - ) + config_entries = hass.config_entries.async_entries( + DOMAIN, include_disabled=False, include_ignore=False + ) + children = [ + BrowseMedia( + title=config_entry.title, + media_class=MediaClass.APP, + media_content_id=f"{MEDIA_PLAYER_PREFIX}{config_entry.entry_id}", + media_content_type=f"{MEDIA_PLAYER_PREFIX}library", + thumbnail="https://brands.home-assistant.io/_/spotify/logo.png", + can_play=False, + can_expand=True, ) + for config_entry in config_entries + ] return BrowseMedia( title="Spotify", media_class=MediaClass.APP, @@ -171,9 +174,15 @@ async def async_browse_media( # Check for config entry specifier, and extract Spotify URI parsed_url = yarl.URL(media_content_id) - if (info := hass.data[DOMAIN].get(parsed_url.host)) is None: + + if ( + parsed_url.host is None + or (entry := hass.config_entries.async_get_entry(parsed_url.host)) is None + or not isinstance(entry.runtime_data, HomeAssistantSpotifyData) + ): raise BrowseError("Invalid Spotify account specified") media_content_id = parsed_url.name + info = entry.runtime_data result = await async_browse_media_internal( hass, diff --git a/homeassistant/components/spotify/media_player.py b/homeassistant/components/spotify/media_player.py index fc7a084939a..fe9614374f7 100644 --- a/homeassistant/components/spotify/media_player.py +++ b/homeassistant/components/spotify/media_player.py @@ -22,7 +22,6 @@ from homeassistant.components.media_player import ( MediaType, RepeatMode, ) -from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_ID from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -30,7 +29,7 @@ from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util.dt import utcnow -from . import HomeAssistantSpotifyData +from . import HomeAssistantSpotifyData, SpotifyConfigEntry from .browse_media import async_browse_media_internal from .const import DOMAIN, MEDIA_PLAYER_PREFIX, PLAYABLE_MEDIA_TYPES, SPOTIFY_SCOPES from .util import fetch_image_url @@ -70,12 +69,12 @@ SPOTIFY_DJ_PLAYLIST = {"uri": "spotify:playlist:37i9dQZF1EYkqdzj48dyYq", "name": async def async_setup_entry( hass: HomeAssistant, - entry: ConfigEntry, + entry: SpotifyConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up Spotify based on a config entry.""" spotify = SpotifyMediaPlayer( - hass.data[DOMAIN][entry.entry_id], + entry.runtime_data, entry.data[CONF_ID], entry.title, ) diff --git a/homeassistant/components/switcher_kis/__init__.py b/homeassistant/components/switcher_kis/__init__.py index abc9091742a..60b3b18b0b0 100644 --- a/homeassistant/components/switcher_kis/__init__.py +++ b/homeassistant/components/switcher_kis/__init__.py @@ -4,15 +4,14 @@ from __future__ import annotations import logging +from aioswitcher.bridge import SwitcherBridge from aioswitcher.device import SwitcherBase from homeassistant.config_entries import ConfigEntry from homeassistant.const import EVENT_HOMEASSISTANT_STOP, Platform from homeassistant.core import Event, HomeAssistant, callback -from .const import DATA_DEVICE, DOMAIN from .coordinator import SwitcherDataUpdateCoordinator -from .utils import async_start_bridge, async_stop_bridge PLATFORMS = [ Platform.BUTTON, @@ -25,20 +24,20 @@ PLATFORMS = [ _LOGGER = logging.getLogger(__name__) -async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +type SwitcherConfigEntry = ConfigEntry[dict[str, SwitcherDataUpdateCoordinator]] + + +async def async_setup_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool: """Set up Switcher from a config entry.""" - hass.data.setdefault(DOMAIN, {}) - hass.data[DOMAIN][DATA_DEVICE] = {} @callback def on_device_data_callback(device: SwitcherBase) -> None: """Use as a callback for device data.""" + coordinators = entry.runtime_data + # Existing device update device data - if device.device_id in hass.data[DOMAIN][DATA_DEVICE]: - coordinator: SwitcherDataUpdateCoordinator = hass.data[DOMAIN][DATA_DEVICE][ - device.device_id - ] + if coordinator := coordinators.get(device.device_id): coordinator.async_set_updated_data(device) return @@ -52,18 +51,21 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: device.device_type.hex_rep, ) - coordinator = hass.data[DOMAIN][DATA_DEVICE][device.device_id] = ( - SwitcherDataUpdateCoordinator(hass, entry, device) - ) + coordinator = SwitcherDataUpdateCoordinator(hass, entry, device) coordinator.async_setup() + coordinators[device.device_id] = coordinator # Must be ready before dispatcher is called await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - await async_start_bridge(hass, on_device_data_callback) + entry.runtime_data = {} + bridge = SwitcherBridge(on_device_data_callback) + await bridge.start() - async def stop_bridge(event: Event) -> None: - await async_stop_bridge(hass) + async def stop_bridge(event: Event | None = None) -> None: + await bridge.stop() + + entry.async_on_unload(stop_bridge) entry.async_on_unload( hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_bridge) @@ -72,12 +74,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return True -async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: +async def async_unload_entry(hass: HomeAssistant, entry: SwitcherConfigEntry) -> bool: """Unload a config entry.""" - await async_stop_bridge(hass) - - unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) - if unload_ok: - hass.data[DOMAIN].pop(DATA_DEVICE) - - return unload_ok + return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) diff --git a/homeassistant/components/switcher_kis/button.py b/homeassistant/components/switcher_kis/button.py index b787043f86c..9454dcabc49 100644 --- a/homeassistant/components/switcher_kis/button.py +++ b/homeassistant/components/switcher_kis/button.py @@ -15,7 +15,6 @@ from aioswitcher.api.remotes import SwitcherBreezeRemote from aioswitcher.device import DeviceCategory from homeassistant.components.button import ButtonEntity, ButtonEntityDescription -from homeassistant.config_entries import ConfigEntry from homeassistant.const import EntityCategory from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -25,6 +24,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity +from . import SwitcherConfigEntry from .const import SIGNAL_DEVICE_ADD from .coordinator import SwitcherDataUpdateCoordinator from .utils import get_breeze_remote_manager @@ -78,7 +78,7 @@ THERMOSTAT_BUTTONS = [ async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: SwitcherConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up Switcher button from config entry.""" diff --git a/homeassistant/components/switcher_kis/climate.py b/homeassistant/components/switcher_kis/climate.py index efcb9c81f0a..9797873c73b 100644 --- a/homeassistant/components/switcher_kis/climate.py +++ b/homeassistant/components/switcher_kis/climate.py @@ -25,7 +25,6 @@ from homeassistant.components.climate import ( ClimateEntityFeature, HVACMode, ) -from homeassistant.config_entries import ConfigEntry from homeassistant.const import ATTR_TEMPERATURE, UnitOfTemperature from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -35,6 +34,7 @@ from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity +from . import SwitcherConfigEntry from .const import SIGNAL_DEVICE_ADD from .coordinator import SwitcherDataUpdateCoordinator from .utils import get_breeze_remote_manager @@ -61,7 +61,7 @@ HA_TO_DEVICE_FAN = {value: key for key, value in DEVICE_FAN_TO_HA.items()} async def async_setup_entry( hass: HomeAssistant, - config_entry: ConfigEntry, + config_entry: SwitcherConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up Switcher climate from config entry.""" diff --git a/homeassistant/components/switcher_kis/const.py b/homeassistant/components/switcher_kis/const.py index 76eb2a3e497..9edc69e4946 100644 --- a/homeassistant/components/switcher_kis/const.py +++ b/homeassistant/components/switcher_kis/const.py @@ -2,9 +2,6 @@ DOMAIN = "switcher_kis" -DATA_BRIDGE = "bridge" -DATA_DEVICE = "device" - DISCOVERY_TIME_SEC = 12 SIGNAL_DEVICE_ADD = "switcher_device_add" diff --git a/homeassistant/components/switcher_kis/diagnostics.py b/homeassistant/components/switcher_kis/diagnostics.py index 441f45198a2..a81e3e25bb9 100644 --- a/homeassistant/components/switcher_kis/diagnostics.py +++ b/homeassistant/components/switcher_kis/diagnostics.py @@ -6,24 +6,23 @@ from dataclasses import asdict from typing import Any from homeassistant.components.diagnostics import async_redact_data -from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from .const import DATA_DEVICE, DOMAIN +from . import SwitcherConfigEntry TO_REDACT = {"device_id", "device_key", "ip_address", "mac_address"} async def async_get_config_entry_diagnostics( - hass: HomeAssistant, entry: ConfigEntry + hass: HomeAssistant, entry: SwitcherConfigEntry ) -> dict[str, Any]: """Return diagnostics for a config entry.""" - devices = hass.data[DOMAIN][DATA_DEVICE] + coordinators = entry.runtime_data return async_redact_data( { "entry": entry.as_dict(), - "devices": [asdict(devices[d].data) for d in devices], + "devices": [asdict(coordinators[d].data) for d in coordinators], }, TO_REDACT, ) diff --git a/homeassistant/components/switcher_kis/utils.py b/homeassistant/components/switcher_kis/utils.py index 79ac565a737..ad23d51e44d 100644 --- a/homeassistant/components/switcher_kis/utils.py +++ b/homeassistant/components/switcher_kis/utils.py @@ -3,9 +3,7 @@ from __future__ import annotations import asyncio -from collections.abc import Callable import logging -from typing import Any from aioswitcher.api.remotes import SwitcherBreezeRemoteManager from aioswitcher.bridge import SwitcherBase, SwitcherBridge @@ -13,29 +11,11 @@ from aioswitcher.bridge import SwitcherBase, SwitcherBridge from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import singleton -from .const import DATA_BRIDGE, DISCOVERY_TIME_SEC, DOMAIN +from .const import DISCOVERY_TIME_SEC _LOGGER = logging.getLogger(__name__) -async def async_start_bridge( - hass: HomeAssistant, on_device_callback: Callable[[SwitcherBase], Any] -) -> None: - """Start switcher UDP bridge.""" - bridge = hass.data[DOMAIN][DATA_BRIDGE] = SwitcherBridge(on_device_callback) - _LOGGER.debug("Starting Switcher bridge") - await bridge.start() - - -async def async_stop_bridge(hass: HomeAssistant) -> None: - """Stop switcher UDP bridge.""" - bridge: SwitcherBridge = hass.data[DOMAIN].get(DATA_BRIDGE) - if bridge is not None: - _LOGGER.debug("Stopping Switcher bridge") - await bridge.stop() - hass.data[DOMAIN].pop(DATA_BRIDGE) - - async def async_has_devices(hass: HomeAssistant) -> bool: """Discover Switcher devices.""" _LOGGER.debug("Starting discovery") diff --git a/homeassistant/components/teslemetry/__init__.py b/homeassistant/components/teslemetry/__init__.py index a425a26b6da..af2276dbcda 100644 --- a/homeassistant/components/teslemetry/__init__.py +++ b/homeassistant/components/teslemetry/__init__.py @@ -30,6 +30,7 @@ PLATFORMS: Final = [ Platform.BINARY_SENSOR, Platform.CLIMATE, Platform.COVER, + Platform.DEVICE_TRACKER, Platform.LOCK, Platform.SELECT, Platform.SENSOR, diff --git a/homeassistant/components/teslemetry/device_tracker.py b/homeassistant/components/teslemetry/device_tracker.py new file mode 100644 index 00000000000..afd947ab3b3 --- /dev/null +++ b/homeassistant/components/teslemetry/device_tracker.py @@ -0,0 +1,85 @@ +"""Device tracker platform for Teslemetry integration.""" + +from __future__ import annotations + +from homeassistant.components.device_tracker import SourceType +from homeassistant.components.device_tracker.config_entry import TrackerEntity +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .entity import TeslemetryVehicleEntity +from .models import TeslemetryVehicleData + + +async def async_setup_entry( + hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback +) -> None: + """Set up the Teslemetry device tracker platform from a config entry.""" + + async_add_entities( + klass(vehicle) + for klass in ( + TeslemetryDeviceTrackerLocationEntity, + TeslemetryDeviceTrackerRouteEntity, + ) + for vehicle in entry.runtime_data.vehicles + ) + + +class TeslemetryDeviceTrackerEntity(TeslemetryVehicleEntity, TrackerEntity): + """Base class for Teslemetry tracker entities.""" + + lat_key: str + lon_key: str + + def __init__( + self, + vehicle: TeslemetryVehicleData, + ) -> None: + """Initialize the device tracker.""" + super().__init__(vehicle, self.key) + + def _async_update_attrs(self) -> None: + """Update the attributes of the device tracker.""" + + self._attr_available = ( + self.get(self.lat_key, False) is not None + and self.get(self.lon_key, False) is not None + ) + + @property + def latitude(self) -> float | None: + """Return latitude value of the device.""" + return self.get(self.lat_key) + + @property + def longitude(self) -> float | None: + """Return longitude value of the device.""" + return self.get(self.lon_key) + + @property + def source_type(self) -> SourceType: + """Return the source type of the device tracker.""" + return SourceType.GPS + + +class TeslemetryDeviceTrackerLocationEntity(TeslemetryDeviceTrackerEntity): + """Vehicle location device tracker class.""" + + key = "location" + lat_key = "drive_state_latitude" + lon_key = "drive_state_longitude" + + +class TeslemetryDeviceTrackerRouteEntity(TeslemetryDeviceTrackerEntity): + """Vehicle navigation device tracker class.""" + + key = "route" + lat_key = "drive_state_active_route_latitude" + lon_key = "drive_state_active_route_longitude" + + @property + def location_name(self) -> str | None: + """Return a location name for the current location of the device.""" + return self.get("drive_state_active_route_destination") diff --git a/homeassistant/components/teslemetry/icons.json b/homeassistant/components/teslemetry/icons.json index 0236bc41c23..3224fee603b 100644 --- a/homeassistant/components/teslemetry/icons.json +++ b/homeassistant/components/teslemetry/icons.json @@ -109,6 +109,7 @@ "off": "mdi:car-seat" } }, + "components_customer_preferred_export_rule": { "default": "mdi:transmission-tower", "state": { @@ -126,6 +127,14 @@ } } }, + "device_tracker": { + "location": { + "default": "mdi:map-marker" + }, + "route": { + "default": "mdi:routes" + } + }, "cover": { "charge_state_charge_port_door_open": { "default": "mdi:ev-plug-ccs2" diff --git a/homeassistant/components/teslemetry/strings.json b/homeassistant/components/teslemetry/strings.json index 322a27929e5..e41fbbd4507 100644 --- a/homeassistant/components/teslemetry/strings.json +++ b/homeassistant/components/teslemetry/strings.json @@ -111,6 +111,14 @@ } } }, + "device_tracker": { + "location": { + "name": "Location" + }, + "route": { + "name": "Route" + } + }, "lock": { "charge_state_charge_port_latch": { "name": "Charge cable lock" diff --git a/homeassistant/components/velbus/manifest.json b/homeassistant/components/velbus/manifest.json index 6f817a23325..f778533cad8 100644 --- a/homeassistant/components/velbus/manifest.json +++ b/homeassistant/components/velbus/manifest.json @@ -13,7 +13,7 @@ "velbus-packet", "velbus-protocol" ], - "requirements": ["velbus-aio==2024.4.1"], + "requirements": ["velbus-aio==2024.5.1"], "usb": [ { "vid": "10CF", diff --git a/homeassistant/components/vera/__init__.py b/homeassistant/components/vera/__init__.py index 5340863fa18..722a6b86d4b 100644 --- a/homeassistant/components/vera/__init__.py +++ b/homeassistant/components/vera/__init__.py @@ -4,7 +4,6 @@ from __future__ import annotations import asyncio from collections import defaultdict -from collections.abc import Awaitable import logging from typing import Any @@ -157,16 +156,16 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: - """Unload Withings config entry.""" + """Unload vera config entry.""" controller_data: ControllerData = get_controller_data(hass, config_entry) - - tasks: list[Awaitable] = [ - hass.config_entries.async_forward_entry_unload(config_entry, platform) - for platform in get_configured_platforms(controller_data) - ] - tasks.append(hass.async_add_executor_job(controller_data.controller.stop)) - await asyncio.gather(*tasks) - + await asyncio.gather( + *( + hass.config_entries.async_unload_platforms( + config_entry, get_configured_platforms(controller_data) + ), + hass.async_add_executor_job(controller_data.controller.stop), + ) + ) return True diff --git a/homeassistant/components/zha/__init__.py b/homeassistant/components/zha/__init__.py index de761138ce1..ed74cde47e1 100644 --- a/homeassistant/components/zha/__init__.py +++ b/homeassistant/components/zha/__init__.py @@ -1,6 +1,5 @@ """Support for Zigbee Home Automation devices.""" -import asyncio import contextlib import copy import logging @@ -238,12 +237,7 @@ async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> websocket_api.async_unload_api(hass) # our components don't have unload methods so no need to look at return values - await asyncio.gather( - *( - hass.config_entries.async_forward_entry_unload(config_entry, platform) - for platform in PLATFORMS - ) - ) + await hass.config_entries.async_unload_platforms(config_entry, PLATFORMS) return True diff --git a/homeassistant/components/zwave_js/__init__.py b/homeassistant/components/zwave_js/__init__.py index e0b0e3cd370..efd9ab717ad 100644 --- a/homeassistant/components/zwave_js/__init__.py +++ b/homeassistant/components/zwave_js/__init__.py @@ -4,7 +4,6 @@ from __future__ import annotations import asyncio from collections import defaultdict -from collections.abc import Coroutine from contextlib import suppress import logging from typing import Any @@ -958,14 +957,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" client: ZwaveClient = entry.runtime_data[DATA_CLIENT] driver_events: DriverEvents = entry.runtime_data[DATA_DRIVER_EVENTS] - - tasks: list[Coroutine] = [ - hass.config_entries.async_forward_entry_unload(entry, platform) + platforms = [ + platform for platform, task in driver_events.platform_setup_tasks.items() if not task.cancel() ] - - unload_ok = all(await asyncio.gather(*tasks)) if tasks else True + unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms) if client.connected and client.driver: await async_disable_server_logging_if_needed(hass, entry, client.driver) diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index cde644a7641..08125acc0da 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -3,12 +3,16 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any import voluptuous as vol from homeassistant.components.climate.intent import INTENT_GET_TEMPERATURE +from homeassistant.components.conversation.trace import ( + ConversationTraceEventType, + async_conversation_trace_append, +) from homeassistant.components.weather.intent import INTENT_GET_WEATHER from homeassistant.core import Context, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError @@ -116,6 +120,10 @@ class API(ABC): async def async_call_tool(self, tool_input: ToolInput) -> JsonObjectType: """Call a LLM tool, validate args and return the response.""" + async_conversation_trace_append( + ConversationTraceEventType.LLM_TOOL_CALL, asdict(tool_input) + ) + for tool in self.async_get_tools(): if tool.name == tool_input.tool_name: break @@ -191,7 +199,10 @@ class AssistAPI(API): async def async_get_api_prompt(self, tool_input: ToolInput) -> str: """Return the prompt for the API.""" - prompt = "Call the intent tools to control Home Assistant. Just pass the name to the intent." + prompt = ( + "Call the intent tools to control Home Assistant. " + "Just pass the name to the intent." + ) if tool_input.device_id: device_reg = device_registry.async_get(self.hass) device = device_reg.async_get(tool_input.device_id) diff --git a/requirements_all.txt b/requirements_all.txt index f0e72b2398e..6baa552b0f6 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1821,7 +1821,7 @@ pyegps==0.2.5 pyenphase==1.20.3 # homeassistant.components.envisalink -pyenvisalink==4.6 +pyenvisalink==4.7 # homeassistant.components.ephember pyephember==0.3.1 @@ -2817,7 +2817,7 @@ vallox-websocket-api==5.1.1 vehicle==2.2.1 # homeassistant.components.velbus -velbus-aio==2024.4.1 +velbus-aio==2024.5.1 # homeassistant.components.venstar venstarcolortouch==0.19 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 1314534700d..ca926fb99ce 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2185,7 +2185,7 @@ vallox-websocket-api==5.1.1 vehicle==2.2.1 # homeassistant.components.velbus -velbus-aio==2024.4.1 +velbus-aio==2024.5.1 # homeassistant.components.venstar venstarcolortouch==0.19 diff --git a/script/hassfest/manifest.py b/script/hassfest/manifest.py index 54ae65e6727..e92ec00b117 100644 --- a/script/hassfest/manifest.py +++ b/script/hassfest/manifest.py @@ -117,7 +117,6 @@ NO_IOT_CLASS = [ # https://github.com/home-assistant/developers.home-assistant/pull/1512 NO_DIAGNOSTICS = [ "dlna_dms", - "fronius", "gdacs", "geonetnz_quakes", "google_assistant_sdk", diff --git a/tests/components/conversation/test_entity.py b/tests/components/conversation/test_entity.py index c84f94c4aa4..109c0ed361f 100644 --- a/tests/components/conversation/test_entity.py +++ b/tests/components/conversation/test_entity.py @@ -2,7 +2,9 @@ from unittest.mock import patch +from homeassistant.components import conversation from homeassistant.core import Context, HomeAssistant, State +from homeassistant.helpers import intent from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -31,6 +33,11 @@ async def test_state_set_and_restore(hass: HomeAssistant) -> None: ) as mock_process, patch("homeassistant.util.dt.utcnow", return_value=now), ): + intent_response = intent.IntentResponse(language="en") + intent_response.async_set_speech("response text") + mock_process.return_value = conversation.ConversationResult( + response=intent_response, + ) await hass.services.async_call( "conversation", "process", diff --git a/tests/components/conversation/test_trace.py b/tests/components/conversation/test_trace.py new file mode 100644 index 00000000000..c586eb8865d --- /dev/null +++ b/tests/components/conversation/test_trace.py @@ -0,0 +1,80 @@ +"""Test for the conversation traces.""" + +from unittest.mock import patch + +import pytest + +from homeassistant.components import conversation +from homeassistant.components.conversation import trace +from homeassistant.core import Context, HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.setup import async_setup_component + + +@pytest.fixture +async def init_components(hass: HomeAssistant): + """Initialize relevant components with empty configs.""" + assert await async_setup_component(hass, "homeassistant", {}) + assert await async_setup_component(hass, "conversation", {}) + assert await async_setup_component(hass, "intent", {}) + + +async def test_converation_trace( + hass: HomeAssistant, + init_components: None, + sl_setup: None, +) -> None: + """Test tracing a conversation.""" + await conversation.async_converse( + hass, "add apples to my shopping list", None, Context() + ) + + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + assert last_trace.get("events") + assert len(last_trace.get("events")) == 1 + trace_event = last_trace["events"][0] + assert ( + trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS + ) + assert trace_event.get("data") + assert trace_event["data"].get("text") == "add apples to my shopping list" + assert last_trace.get("result") + assert ( + last_trace["result"] + .get("response", {}) + .get("speech", {}) + .get("plain", {}) + .get("speech") + == "Added apples" + ) + + +async def test_converation_trace_error( + hass: HomeAssistant, + init_components: None, + sl_setup: None, +) -> None: + """Test tracing a conversation.""" + with ( + patch( + "homeassistant.components.conversation.default_agent.DefaultAgent.async_process", + side_effect=HomeAssistantError("Failed to talk to agent"), + ), + pytest.raises(HomeAssistantError), + ): + await conversation.async_converse( + hass, "add apples to my shopping list", None, Context() + ) + + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + assert last_trace.get("events") + assert len(last_trace.get("events")) == 1 + trace_event = last_trace["events"][0] + assert ( + trace_event.get("event_type") == trace.ConversationTraceEventType.ASYNC_PROCESS + ) + assert last_trace.get("error") == "Failed to talk to agent" diff --git a/tests/components/fronius/__init__.py b/tests/components/fronius/__init__.py index f1630d6cd7e..6cefae734a0 100644 --- a/tests/components/fronius/__init__.py +++ b/tests/components/fronius/__init__.py @@ -25,6 +25,7 @@ async def setup_fronius_integration( """Create the Fronius integration.""" entry = MockConfigEntry( domain=DOMAIN, + entry_id="f1e2b9837e8adaed6fa682acaa216fd8", unique_id=unique_id, # has to match mocked logger unique_id data={ CONF_HOST: MOCK_HOST, diff --git a/tests/components/fronius/snapshots/test_diagnostics.ambr b/tests/components/fronius/snapshots/test_diagnostics.ambr new file mode 100644 index 00000000000..f23d63a58e3 --- /dev/null +++ b/tests/components/fronius/snapshots/test_diagnostics.ambr @@ -0,0 +1,370 @@ +# serializer version: 1 +# name: test_diagnostics + dict({ + 'config_entry': dict({ + 'data': dict({ + 'host': 'http://fronius', + 'is_logger': True, + }), + 'disabled_by': None, + 'domain': 'fronius', + 'entry_id': 'f1e2b9837e8adaed6fa682acaa216fd8', + 'minor_version': 1, + 'options': dict({ + }), + 'pref_disable_new_entities': False, + 'pref_disable_polling': False, + 'source': 'user', + 'title': 'Mock Title', + 'unique_id': '**REDACTED**', + 'version': 1, + }), + 'coordinators': dict({ + 'inverters': dict({ + '1': dict({ + 'current_ac': dict({ + 'unit': 'A', + 'value': 5.19, + }), + 'current_dc': dict({ + 'unit': 'A', + 'value': 2.19, + }), + 'energy_day': dict({ + 'unit': 'Wh', + 'value': 1113, + }), + 'energy_total': dict({ + 'unit': 'Wh', + 'value': 44188000, + }), + 'energy_year': dict({ + 'unit': 'Wh', + 'value': 25508798, + }), + 'error_code': dict({ + 'value': 0, + }), + 'frequency_ac': dict({ + 'unit': 'Hz', + 'value': 49.94, + }), + 'led_color': dict({ + 'value': 2, + }), + 'led_state': dict({ + 'value': 0, + }), + 'power_ac': dict({ + 'unit': 'W', + 'value': 1190, + }), + 'status': dict({ + 'Code': 0, + 'Reason': '', + 'UserMessage': '', + }), + 'status_code': dict({ + 'value': 7, + }), + 'timestamp': dict({ + 'value': '2021-10-07T10:01:17+02:00', + }), + 'voltage_ac': dict({ + 'unit': 'V', + 'value': 227.9, + }), + 'voltage_dc': dict({ + 'unit': 'V', + 'value': 518, + }), + }), + }), + 'logger': dict({ + 'system': dict({ + 'cash_factor': dict({ + 'unit': 'EUR/kWh', + 'value': 0.07800000160932541, + }), + 'co2_factor': dict({ + 'unit': 'kg/kWh', + 'value': 0.5299999713897705, + }), + 'delivery_factor': dict({ + 'unit': 'EUR/kWh', + 'value': 0.15000000596046448, + }), + 'hardware_platform': dict({ + 'value': 'wilma', + }), + 'hardware_version': dict({ + 'value': '2.4E', + }), + 'product_type': dict({ + 'value': 'fronius-datamanager-card', + }), + 'software_version': dict({ + 'value': '3.18.7-1', + }), + 'status': dict({ + 'Code': 0, + 'Reason': '', + 'UserMessage': '', + }), + 'time_zone': dict({ + 'value': 'CEST', + }), + 'time_zone_location': dict({ + 'value': 'Vienna', + }), + 'timestamp': dict({ + 'value': '2021-10-06T23:56:32+02:00', + }), + 'unique_identifier': '**REDACTED**', + 'utc_offset': dict({ + 'value': 7200, + }), + }), + }), + 'meter': dict({ + '0': dict({ + 'current_ac_phase_1': dict({ + 'unit': 'A', + 'value': 7.755, + }), + 'current_ac_phase_2': dict({ + 'unit': 'A', + 'value': 6.68, + }), + 'current_ac_phase_3': dict({ + 'unit': 'A', + 'value': 10.102, + }), + 'enable': dict({ + 'value': 1, + }), + 'energy_reactive_ac_consumed': dict({ + 'unit': 'VArh', + 'value': 59960790, + }), + 'energy_reactive_ac_produced': dict({ + 'unit': 'VArh', + 'value': 723160, + }), + 'energy_real_ac_minus': dict({ + 'unit': 'Wh', + 'value': 35623065, + }), + 'energy_real_ac_plus': dict({ + 'unit': 'Wh', + 'value': 15303334, + }), + 'energy_real_consumed': dict({ + 'unit': 'Wh', + 'value': 15303334, + }), + 'energy_real_produced': dict({ + 'unit': 'Wh', + 'value': 35623065, + }), + 'frequency_phase_average': dict({ + 'unit': 'Hz', + 'value': 50, + }), + 'manufacturer': dict({ + 'value': 'Fronius', + }), + 'meter_location': dict({ + 'value': 0, + }), + 'model': dict({ + 'value': 'Smart Meter 63A', + }), + 'power_apparent': dict({ + 'unit': 'VA', + 'value': 5592.57, + }), + 'power_apparent_phase_1': dict({ + 'unit': 'VA', + 'value': 1772.793, + }), + 'power_apparent_phase_2': dict({ + 'unit': 'VA', + 'value': 1527.048, + }), + 'power_apparent_phase_3': dict({ + 'unit': 'VA', + 'value': 2333.562, + }), + 'power_factor': dict({ + 'value': 1, + }), + 'power_factor_phase_1': dict({ + 'value': -0.99, + }), + 'power_factor_phase_2': dict({ + 'value': -0.99, + }), + 'power_factor_phase_3': dict({ + 'value': 0.99, + }), + 'power_reactive': dict({ + 'unit': 'VAr', + 'value': 2.87, + }), + 'power_reactive_phase_1': dict({ + 'unit': 'VAr', + 'value': 51.48, + }), + 'power_reactive_phase_2': dict({ + 'unit': 'VAr', + 'value': 115.63, + }), + 'power_reactive_phase_3': dict({ + 'unit': 'VAr', + 'value': -164.24, + }), + 'power_real': dict({ + 'unit': 'W', + 'value': 5592.57, + }), + 'power_real_phase_1': dict({ + 'unit': 'W', + 'value': 1765.55, + }), + 'power_real_phase_2': dict({ + 'unit': 'W', + 'value': 1515.8, + }), + 'power_real_phase_3': dict({ + 'unit': 'W', + 'value': 2311.22, + }), + 'serial': '**REDACTED**', + 'visible': dict({ + 'value': 1, + }), + 'voltage_ac_phase_1': dict({ + 'unit': 'V', + 'value': 228.6, + }), + 'voltage_ac_phase_2': dict({ + 'unit': 'V', + 'value': 228.6, + }), + 'voltage_ac_phase_3': dict({ + 'unit': 'V', + 'value': 231, + }), + 'voltage_ac_phase_to_phase_12': dict({ + 'unit': 'V', + 'value': 395.9, + }), + 'voltage_ac_phase_to_phase_23': dict({ + 'unit': 'V', + 'value': 398, + }), + 'voltage_ac_phase_to_phase_31': dict({ + 'unit': 'V', + 'value': 398, + }), + }), + }), + 'ohmpilot': None, + 'power_flow': dict({ + 'power_flow': dict({ + 'energy_day': dict({ + 'unit': 'Wh', + 'value': 1101.7000732421875, + }), + 'energy_total': dict({ + 'unit': 'Wh', + 'value': 44188000, + }), + 'energy_year': dict({ + 'unit': 'Wh', + 'value': 25508788, + }), + 'meter_location': dict({ + 'value': 'grid', + }), + 'meter_mode': dict({ + 'value': 'meter', + }), + 'power_battery': dict({ + 'unit': 'W', + 'value': None, + }), + 'power_grid': dict({ + 'unit': 'W', + 'value': 1703.74, + }), + 'power_load': dict({ + 'unit': 'W', + 'value': -2814.74, + }), + 'power_photovoltaics': dict({ + 'unit': 'W', + 'value': 1111, + }), + 'relative_autonomy': dict({ + 'unit': '%', + 'value': 39.4707859340472, + }), + 'relative_self_consumption': dict({ + 'unit': '%', + 'value': 100, + }), + 'status': dict({ + 'Code': 0, + 'Reason': '', + 'UserMessage': '', + }), + 'timestamp': dict({ + 'value': '2021-10-07T10:00:43+02:00', + }), + }), + }), + 'storage': None, + }), + 'inverter_info': dict({ + 'inverters': list([ + dict({ + 'custom_name': dict({ + 'value': 'Symo 20', + }), + 'device_id': dict({ + 'value': '1', + }), + 'device_type': dict({ + 'manufacturer': 'Fronius', + 'model': 'Symo 20.0-3-M', + 'value': 121, + }), + 'error_code': dict({ + 'value': 0, + }), + 'pv_power': dict({ + 'unit': 'W', + 'value': 23100, + }), + 'show': dict({ + 'value': 1, + }), + 'status_code': dict({ + 'value': 7, + }), + 'unique_id': '**REDACTED**', + }), + ]), + 'status': dict({ + 'Code': 0, + 'Reason': '', + 'UserMessage': '', + }), + 'timestamp': dict({ + 'value': '2021-10-07T13:41:00+02:00', + }), + }), + }) +# --- diff --git a/tests/components/fronius/test_diagnostics.py b/tests/components/fronius/test_diagnostics.py new file mode 100644 index 00000000000..7d8a49dcb7d --- /dev/null +++ b/tests/components/fronius/test_diagnostics.py @@ -0,0 +1,31 @@ +"""Tests for the diagnostics data provided by the KNX integration.""" + +from syrupy import SnapshotAssertion + +from homeassistant.core import HomeAssistant + +from . import mock_responses, setup_fronius_integration + +from tests.components.diagnostics import get_diagnostics_for_config_entry +from tests.test_util.aiohttp import AiohttpClientMocker +from tests.typing import ClientSessionGenerator + + +async def test_diagnostics( + hass: HomeAssistant, + hass_client: ClientSessionGenerator, + aioclient_mock: AiohttpClientMocker, + snapshot: SnapshotAssertion, +) -> None: + """Test diagnostics.""" + mock_responses(aioclient_mock) + entry = await setup_fronius_integration(hass) + + assert ( + await get_diagnostics_for_config_entry( + hass, + hass_client, + entry, + ) + == snapshot + ) diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr index ebc918bbf31..6d37c1d1823 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr @@ -1,4 +1,114 @@ # serializer version: 1 +# name: test_chat_history + list([ + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 1.0, + 'top_k': 64, + 'top_p': 0.95, + }), + 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', + }), + 'tools': None, + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + dict({ + 'parts': ''' + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', + 'role': 'user', + }), + dict({ + 'parts': 'Ok', + 'role': 'model', + }), + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + '1st user request', + ), + dict({ + }), + ), + tuple( + '', + tuple( + ), + dict({ + 'generation_config': dict({ + 'max_output_tokens': 150, + 'temperature': 1.0, + 'top_k': 64, + 'top_p': 0.95, + }), + 'model_name': 'models/gemini-1.5-flash-latest', + 'safety_settings': dict({ + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', + }), + 'tools': None, + }), + ), + tuple( + '().start_chat', + tuple( + ), + dict({ + 'history': list([ + dict({ + 'parts': ''' + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', + 'role': 'user', + }), + dict({ + 'parts': 'Ok', + 'role': 'model', + }), + dict({ + 'parts': '1st user request', + 'role': 'user', + }), + dict({ + 'parts': '1st model response', + 'role': 'model', + }), + ]), + }), + ), + tuple( + '().start_chat().send_message_async', + tuple( + '2nd user request', + ), + dict({ + }), + ), + ]) +# --- # name: test_default_prompt[config_entry_options0-None] list([ tuple( @@ -14,10 +124,10 @@ }), 'model_name': 'models/gemini-1.5-flash-latest', 'safety_settings': dict({ - 'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', - 'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', - 'HATE': 'BLOCK_LOW_AND_ABOVE', - 'SEXUAL': 'BLOCK_LOW_AND_ABOVE', + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), 'tools': None, }), @@ -29,7 +139,10 @@ dict({ 'history': list([ dict({ - 'parts': 'Answer in plain text. Keep it simple and to the point.', + 'parts': ''' + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', 'role': 'user', }), dict({ @@ -64,10 +177,10 @@ }), 'model_name': 'models/gemini-1.5-flash-latest', 'safety_settings': dict({ - 'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', - 'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', - 'HATE': 'BLOCK_LOW_AND_ABOVE', - 'SEXUAL': 'BLOCK_LOW_AND_ABOVE', + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), 'tools': None, }), @@ -79,7 +192,10 @@ dict({ 'history': list([ dict({ - 'parts': 'Answer in plain text. Keep it simple and to the point.', + 'parts': ''' + Answer in plain text. Keep it simple and to the point. + Only if the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', 'role': 'user', }), dict({ @@ -114,10 +230,10 @@ }), 'model_name': 'models/gemini-1.5-flash-latest', 'safety_settings': dict({ - 'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', - 'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', - 'HATE': 'BLOCK_LOW_AND_ABOVE', - 'SEXUAL': 'BLOCK_LOW_AND_ABOVE', + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), 'tools': None, }), @@ -130,8 +246,8 @@ 'history': list([ dict({ 'parts': ''' - Call the intent tools to control Home Assistant. Just pass the name to the intent. Answer in plain text. Keep it simple and to the point. + Call the intent tools to control Home Assistant. Just pass the name to the intent. ''', 'role': 'user', }), @@ -167,10 +283,10 @@ }), 'model_name': 'models/gemini-1.5-flash-latest', 'safety_settings': dict({ - 'DANGEROUS': 'BLOCK_LOW_AND_ABOVE', - 'HARASSMENT': 'BLOCK_LOW_AND_ABOVE', - 'HATE': 'BLOCK_LOW_AND_ABOVE', - 'SEXUAL': 'BLOCK_LOW_AND_ABOVE', + 'DANGEROUS': 'BLOCK_MEDIUM_AND_ABOVE', + 'HARASSMENT': 'BLOCK_MEDIUM_AND_ABOVE', + 'HATE': 'BLOCK_MEDIUM_AND_ABOVE', + 'SEXUAL': 'BLOCK_MEDIUM_AND_ABOVE', }), 'tools': None, }), @@ -183,8 +299,8 @@ 'history': list([ dict({ 'parts': ''' - Call the intent tools to control Home Assistant. Just pass the name to the intent. Answer in plain text. Keep it simple and to the point. + Call the intent tools to control Home Assistant. Just pass the name to the intent. ''', 'role': 'user', }), diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py index af7aebace35..4c208c240b8 100644 --- a/tests/components/google_generative_ai_conversation/test_conversation.py +++ b/tests/components/google_generative_ai_conversation/test_conversation.py @@ -2,12 +2,14 @@ from unittest.mock import AsyncMock, MagicMock, patch -from google.api_core.exceptions import ClientError +from google.api_core.exceptions import GoogleAPICallError +import google.generativeai.types as genai_types import pytest from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation +from homeassistant.components.conversation import trace from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -150,6 +152,57 @@ async def test_default_prompt( assert mock_get_tools.called == (CONF_LLM_HASS_API in config_entry_options) +async def test_chat_history( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, + snapshot: SnapshotAssertion, +) -> None: + """Test that the agent keeps track of the chat history.""" + with patch("google.generativeai.GenerativeModel") as mock_model: + mock_chat = AsyncMock() + mock_model.return_value.start_chat.return_value = mock_chat + chat_response = MagicMock() + mock_chat.send_message_async.return_value = chat_response + mock_part = MagicMock() + mock_part.function_call = None + chat_response.parts = [mock_part] + chat_response.text = "1st model response" + mock_chat.history = [ + {"role": "user", "parts": "prompt"}, + {"role": "model", "parts": "Ok"}, + {"role": "user", "parts": "1st user request"}, + {"role": "model", "parts": "1st model response"}, + ] + result = await conversation.async_converse( + hass, + "1st user request", + None, + Context(), + agent_id=mock_config_entry.entry_id, + ) + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.as_dict()["speech"]["plain"]["speech"] + == "1st model response" + ) + chat_response.text = "2nd model response" + result = await conversation.async_converse( + hass, + "2nd user request", + result.conversation_id, + Context(), + agent_id=mock_config_entry.entry_id, + ) + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert ( + result.response.as_dict()["speech"]["plain"]["speech"] + == "2nd model response" + ) + + assert [tuple(mock_call) for mock_call in mock_model.mock_calls] == snapshot + + @patch( "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" ) @@ -233,6 +286,20 @@ async def test_function_call( ), ) + # Test conversating tracing + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + trace_events = last_trace.get("events", []) + assert [event["event_type"] for event in trace_events] == [ + trace.ConversationTraceEventType.ASYNC_PROCESS, + trace.ConversationTraceEventType.AGENT_DETAIL, + trace.ConversationTraceEventType.LLM_TOOL_CALL, + ] + # AGENT_DETAIL event contains the raw prompt passed to the model + detail_event = trace_events[1] + assert "Answer in plain text" in detail_event["data"]["messages"][0]["parts"] + @patch( "homeassistant.components.google_generative_ai_conversation.conversation.llm.AssistAPI.async_get_tools" @@ -325,7 +392,7 @@ async def test_error_handling( with patch("google.generativeai.GenerativeModel") as mock_model: mock_chat = AsyncMock() mock_model.return_value.start_chat.return_value = mock_chat - mock_chat.send_message_async.side_effect = ClientError("some error") + mock_chat.send_message_async.side_effect = GoogleAPICallError("some error") result = await conversation.async_converse( hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id ) @@ -340,7 +407,28 @@ async def test_error_handling( async def test_blocked_response( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component ) -> None: - """Test response was blocked.""" + """Test blocked response.""" + with patch("google.generativeai.GenerativeModel") as mock_model: + mock_chat = AsyncMock() + mock_model.return_value.start_chat.return_value = mock_chat + mock_chat.send_message_async.side_effect = genai_types.StopCandidateException( + "finish_reason: SAFETY\n" + ) + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + assert result.response.as_dict()["speech"]["plain"]["speech"] == ( + "The message got blocked by your safety settings" + ) + + +async def test_empty_response( + hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component +) -> None: + """Test empty response.""" with patch("google.generativeai.GenerativeModel") as mock_model: mock_chat = AsyncMock() mock_model.return_value.start_chat.return_value = mock_chat @@ -358,6 +446,32 @@ async def test_blocked_response( ) +async def test_invalid_llm_api( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test handling of invalid llm api.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={**mock_config_entry.options, CONF_LLM_HASS_API: "invalid_llm_api"}, + ) + + result = await conversation.async_converse( + hass, + "hello", + None, + Context(), + agent_id=mock_config_entry.entry_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ERROR, result + assert result.response.error_code == "unknown", result + assert result.response.as_dict()["speech"]["plain"]["speech"] == ( + "Error preparing LLM API: API invalid_llm_api not found" + ) + + async def test_template_error( hass: HomeAssistant, mock_config_entry: MockConfigEntry ) -> None: diff --git a/tests/components/mqtt/test_device_trigger.py b/tests/components/mqtt/test_device_trigger.py index 1ef80c0b81e..b01e40d311e 100644 --- a/tests/components/mqtt/test_device_trigger.py +++ b/tests/components/mqtt/test_device_trigger.py @@ -529,16 +529,16 @@ async def test_non_unique_triggers( async_fire_mqtt_message(hass, "foobar/triggers/button1", "short_press") await hass.async_block_till_done() assert len(calls) == 2 - assert calls[0].data["some"] == "press1" - assert calls[1].data["some"] == "press2" + all_calls = {calls[0].data["some"], calls[1].data["some"]} + assert all_calls == {"press1", "press2"} # Trigger second config references to same trigger # and triggers both attached instances. async_fire_mqtt_message(hass, "foobar/triggers/button2", "long_press") await hass.async_block_till_done() assert len(calls) == 2 - assert calls[0].data["some"] == "press1" - assert calls[1].data["some"] == "press2" + all_calls = {calls[0].data["some"], calls[1].data["some"]} + assert all_calls == {"press1", "press2"} # Removing the first trigger will clean up calls.clear() diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 6e81226ee14..5a6cb3dd7ed 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -4,6 +4,7 @@ import asyncio from collections.abc import Generator from copy import deepcopy from datetime import datetime, timedelta +from functools import partial import json import logging import socket @@ -1050,6 +1051,27 @@ async def test_subscribe_topic_not_initialize( await mqtt.async_subscribe(hass, "test-topic", record_calls) +async def test_subscribe_mqtt_config_entry_disabled( + hass: HomeAssistant, mqtt_mock: MqttMockHAClient +) -> None: + """Test the subscription of a topic when MQTT config entry is disabled.""" + mqtt_mock.connected = True + + mqtt_config_entry = hass.config_entries.async_entries(mqtt.DOMAIN)[0] + assert mqtt_config_entry.state is ConfigEntryState.LOADED + + assert await hass.config_entries.async_unload(mqtt_config_entry.entry_id) + assert mqtt_config_entry.state is ConfigEntryState.NOT_LOADED + + await hass.config_entries.async_set_disabled_by( + mqtt_config_entry.entry_id, ConfigEntryDisabler.USER + ) + mqtt_mock.connected = False + + with pytest.raises(HomeAssistantError, match=r".*MQTT is not enabled"): + await mqtt.async_subscribe(hass, "test-topic", record_calls) + + @patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0) @patch("homeassistant.components.mqtt.client.UNSUBSCRIBE_COOLDOWN", 0.2) async def test_subscribe_and_resubscribe( @@ -2912,8 +2934,8 @@ async def test_message_callback_exception_gets_logged( await mqtt_mock_entry() @callback - def bad_handler(*args) -> None: - """Record calls.""" + def bad_handler(msg: ReceiveMessage) -> None: + """Handle callback.""" raise ValueError("This is a bad message callback") await mqtt.async_subscribe(hass, "test-topic", bad_handler) @@ -2926,6 +2948,40 @@ async def test_message_callback_exception_gets_logged( ) +@pytest.mark.no_fail_on_log_exception +async def test_message_partial_callback_exception_gets_logged( + hass: HomeAssistant, + caplog: pytest.LogCaptureFixture, + mqtt_mock_entry: MqttMockHAClientGenerator, +) -> None: + """Test exception raised by message handler.""" + await mqtt_mock_entry() + + @callback + def bad_handler(msg: ReceiveMessage) -> None: + """Handle callback.""" + raise ValueError("This is a bad message callback") + + def parial_handler( + msg_callback: MessageCallbackType, + attributes: set[str], + msg: ReceiveMessage, + ) -> None: + """Partial callback handler.""" + msg_callback(msg) + + await mqtt.async_subscribe( + hass, "test-topic", partial(parial_handler, bad_handler, {"some_attr"}) + ) + async_fire_mqtt_message(hass, "test-topic", "test") + await hass.async_block_till_done() + + assert ( + "Exception in bad_handler when handling msg on 'test-topic':" + " 'test'" in caplog.text + ) + + async def test_mqtt_ws_subscription( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, @@ -3787,7 +3843,7 @@ async def test_unload_config_entry( async def test_publish_or_subscribe_without_valid_config_entry( hass: HomeAssistant, record_calls: MessageCallbackType ) -> None: - """Test internal publish function with bas use cases.""" + """Test internal publish function with bad use cases.""" with pytest.raises(HomeAssistantError): await mqtt.async_publish( hass, "some-topic", "test-payload", qos=0, retain=False, encoding=None diff --git a/tests/components/nws/conftest.py b/tests/components/nws/conftest.py index 48401fe87ba..65276a1a115 100644 --- a/tests/components/nws/conftest.py +++ b/tests/components/nws/conftest.py @@ -11,8 +11,12 @@ from .const import DEFAULT_FORECAST, DEFAULT_OBSERVATION @pytest.fixture def mock_simple_nws(): """Mock pynws SimpleNWS with default values.""" - - with patch("homeassistant.components.nws.SimpleNWS") as mock_nws: + # set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests + with ( + patch("homeassistant.components.nws.SimpleNWS") as mock_nws, + patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0), + patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0), + ): instance = mock_nws.return_value instance.set_station = AsyncMock(return_value=None) instance.update_observation = AsyncMock(return_value=None) @@ -29,7 +33,12 @@ def mock_simple_nws(): @pytest.fixture def mock_simple_nws_times_out(): """Mock pynws SimpleNWS that times out.""" - with patch("homeassistant.components.nws.SimpleNWS") as mock_nws: + # set RETRY_STOP and RETRY_INTERVAL to avoid retries inside pynws in tests + with ( + patch("homeassistant.components.nws.SimpleNWS") as mock_nws, + patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0), + patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0), + ): instance = mock_nws.return_value instance.set_station = AsyncMock(side_effect=asyncio.TimeoutError) instance.update_observation = AsyncMock(side_effect=asyncio.TimeoutError) diff --git a/tests/components/nws/test_weather.py b/tests/components/nws/test_weather.py index 32cbfe4befe..5406636c324 100644 --- a/tests/components/nws/test_weather.py +++ b/tests/components/nws/test_weather.py @@ -1,7 +1,6 @@ """Tests for the NWS weather component.""" from datetime import timedelta -from unittest.mock import patch import aiohttp from freezegun.api import FrozenDateTimeFactory @@ -24,7 +23,6 @@ from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er from homeassistant.setup import async_setup_component -import homeassistant.util.dt as dt_util from homeassistant.util.unit_system import METRIC_SYSTEM, US_CUSTOMARY_SYSTEM from .const import ( @@ -127,47 +125,43 @@ async def test_data_caching_error_observation( caplog, ) -> None: """Test caching of data with errors.""" - with ( - patch("homeassistant.components.nws.coordinator.RETRY_STOP", 0), - patch("homeassistant.components.nws.coordinator.RETRY_INTERVAL", 0), - ): - instance = mock_simple_nws.return_value + instance = mock_simple_nws.return_value - entry = MockConfigEntry( - domain=nws.DOMAIN, - data=NWS_CONFIG, - ) - entry.add_to_hass(hass) - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() + entry = MockConfigEntry( + domain=nws.DOMAIN, + data=NWS_CONFIG, + ) + entry.add_to_hass(hass) + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() - state = hass.states.get("weather.abc") - assert state.state == "sunny" + state = hass.states.get("weather.abc") + assert state.state == "sunny" - # data is still valid even when update fails - instance.update_observation.side_effect = NwsNoDataError("Test") + # data is still valid even when update fails + instance.update_observation.side_effect = NwsNoDataError("Test") - freezer.tick(DEFAULT_SCAN_INTERVAL + timedelta(seconds=100)) - async_fire_time_changed(hass) - await hass.async_block_till_done() + freezer.tick(DEFAULT_SCAN_INTERVAL + timedelta(seconds=100)) + async_fire_time_changed(hass) + await hass.async_block_till_done() - state = hass.states.get("weather.abc") - assert state.state == "sunny" + state = hass.states.get("weather.abc") + assert state.state == "sunny" - assert ( - "NWS observation update failed, but data still valid. Last success: " - in caplog.text - ) + assert ( + "NWS observation update failed, but data still valid. Last success: " + in caplog.text + ) - # data is no longer valid after OBSERVATION_VALID_TIME - freezer.tick(OBSERVATION_VALID_TIME + timedelta(seconds=1)) - async_fire_time_changed(hass) - await hass.async_block_till_done() + # data is no longer valid after OBSERVATION_VALID_TIME + freezer.tick(OBSERVATION_VALID_TIME + timedelta(seconds=1)) + async_fire_time_changed(hass) + await hass.async_block_till_done() - state = hass.states.get("weather.abc") - assert state.state == STATE_UNAVAILABLE + state = hass.states.get("weather.abc") + assert state.state == STATE_UNAVAILABLE - assert "Error fetching NWS observation station ABC data: Test" in caplog.text + assert "Error fetching NWS observation station ABC data: Test" in caplog.text async def test_no_data_error_observation( @@ -302,26 +296,23 @@ async def test_error_observation( hass: HomeAssistant, mock_simple_nws, no_sensor ) -> None: """Test error during update observation.""" - utc_time = dt_util.utcnow() - with patch("homeassistant.components.nws.coordinator.utcnow") as mock_utc: - mock_utc.return_value = utc_time - instance = mock_simple_nws.return_value - # first update fails - instance.update_observation.side_effect = aiohttp.ClientError + instance = mock_simple_nws.return_value + # first update fails + instance.update_observation.side_effect = aiohttp.ClientError - entry = MockConfigEntry( - domain=nws.DOMAIN, - data=NWS_CONFIG, - ) - entry.add_to_hass(hass) - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() + entry = MockConfigEntry( + domain=nws.DOMAIN, + data=NWS_CONFIG, + ) + entry.add_to_hass(hass) + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() - instance.update_observation.assert_called_once() + instance.update_observation.assert_called_once() - state = hass.states.get("weather.abc") - assert state - assert state.state == STATE_UNAVAILABLE + state = hass.states.get("weather.abc") + assert state + assert state.state == STATE_UNAVAILABLE async def test_new_config_entry(hass: HomeAssistant, no_sensor) -> None: diff --git a/tests/components/ollama/test_conversation.py b/tests/components/ollama/test_conversation.py index 080d0d34f2d..b6f0be3c414 100644 --- a/tests/components/ollama/test_conversation.py +++ b/tests/components/ollama/test_conversation.py @@ -6,6 +6,7 @@ from ollama import Message, ResponseError import pytest from homeassistant.components import conversation, ollama +from homeassistant.components.conversation import trace from homeassistant.components.homeassistant.exposed_entities import async_expose_entity from homeassistant.const import ATTR_FRIENDLY_NAME, MATCH_ALL from homeassistant.core import Context, HomeAssistant @@ -110,6 +111,19 @@ async def test_chat( ), result assert result.response.speech["plain"]["speech"] == "test response" + # Test Conversation tracing + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + trace_events = last_trace.get("events", []) + assert [event["event_type"] for event in trace_events] == [ + trace.ConversationTraceEventType.ASYNC_PROCESS, + trace.ConversationTraceEventType.AGENT_DETAIL, + ] + # AGENT_DETAIL event contains the raw prompt passed to the model + detail_event = trace_events[1] + assert "The current time is" in detail_event["data"]["messages"][0]["content"] + async def test_message_history_trimming( hass: HomeAssistant, mock_config_entry: MockConfigEntry, mock_init_component diff --git a/tests/components/openai_conversation/test_config_flow.py b/tests/components/openai_conversation/test_config_flow.py index 57f03d0c0bf..234e518b3c5 100644 --- a/tests/components/openai_conversation/test_config_flow.py +++ b/tests/components/openai_conversation/test_config_flow.py @@ -9,9 +9,17 @@ import pytest from homeassistant import config_entries from homeassistant.components.openai_conversation.const import ( CONF_CHAT_MODEL, - DEFAULT_CHAT_MODEL, + CONF_MAX_TOKENS, + CONF_PROMPT, + CONF_RECOMMENDED, + CONF_TEMPERATURE, + CONF_TOP_P, DOMAIN, + RECOMMENDED_CHAT_MODEL, + RECOMMENDED_MAX_TOKENS, + RECOMMENDED_TOP_P, ) +from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -75,7 +83,7 @@ async def test_options( assert options["type"] is FlowResultType.CREATE_ENTRY assert options["data"]["prompt"] == "Speak like a pirate" assert options["data"]["max_tokens"] == 200 - assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL + assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL @pytest.mark.parametrize( @@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non assert result2["type"] is FlowResultType.FORM assert result2["errors"] == {"base": error} + + +@pytest.mark.parametrize( + ("current_options", "new_options", "expected_options"), + [ + ( + { + CONF_RECOMMENDED: True, + CONF_LLM_HASS_API: "none", + CONF_PROMPT: "bla", + }, + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_TEMPERATURE: 0.3, + }, + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_TEMPERATURE: 0.3, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_TOP_P: RECOMMENDED_TOP_P, + CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + }, + ), + ( + { + CONF_RECOMMENDED: False, + CONF_PROMPT: "Speak like a pirate", + CONF_TEMPERATURE: 0.3, + CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, + CONF_TOP_P: RECOMMENDED_TOP_P, + CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, + }, + { + CONF_RECOMMENDED: True, + CONF_LLM_HASS_API: "assist", + CONF_PROMPT: "", + }, + { + CONF_RECOMMENDED: True, + CONF_LLM_HASS_API: "assist", + CONF_PROMPT: "", + }, + ), + ], +) +async def test_options_switching( + hass: HomeAssistant, + mock_config_entry, + mock_init_component, + current_options, + new_options, + expected_options, +) -> None: + """Test the options form.""" + hass.config_entries.async_update_entry(mock_config_entry, options=current_options) + options_flow = await hass.config_entries.options.async_init( + mock_config_entry.entry_id + ) + if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED): + options_flow = await hass.config_entries.options.async_configure( + options_flow["flow_id"], + { + **current_options, + CONF_RECOMMENDED: new_options[CONF_RECOMMENDED], + }, + ) + options = await hass.config_entries.options.async_configure( + options_flow["flow_id"], + new_options, + ) + await hass.async_block_till_done() + assert options["type"] is FlowResultType.CREATE_ENTRY + assert options["data"] == expected_options diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 319295374a7..3fa5c307b6d 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -15,6 +15,7 @@ from syrupy.assertion import SnapshotAssertion import voluptuous as vol from homeassistant.components import conversation +from homeassistant.components.conversation import trace from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant from homeassistant.exceptions import HomeAssistantError @@ -200,6 +201,20 @@ async def test_function_call( ), ) + # Test Conversation tracing + traces = trace.async_get_traces() + assert traces + last_trace = traces[-1].as_dict() + trace_events = last_trace.get("events", []) + assert [event["event_type"] for event in trace_events] == [ + trace.ConversationTraceEventType.ASYNC_PROCESS, + trace.ConversationTraceEventType.AGENT_DETAIL, + trace.ConversationTraceEventType.LLM_TOOL_CALL, + ] + # AGENT_DETAIL event contains the raw prompt passed to the model + detail_event = trace_events[1] + assert "Answer in plain text" in detail_event["data"]["messages"][0]["content"] + @patch( "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" diff --git a/tests/components/plaato/__init__.py b/tests/components/plaato/__init__.py index 6c66478eba1..5a2b2a68d44 100644 --- a/tests/components/plaato/__init__.py +++ b/tests/components/plaato/__init__.py @@ -2,6 +2,7 @@ from unittest.mock import patch +from freezegun import freeze_time from pyplaato.models.airlock import PlaatoAirlock from pyplaato.models.device import PlaatoDeviceType from pyplaato.models.keg import PlaatoKeg @@ -23,6 +24,7 @@ AIRLOCK_DATA = {} KEG_DATA = {} +@freeze_time("2024-05-24 12:00:00", tz_offset=0) async def init_integration( hass: HomeAssistant, device_type: PlaatoDeviceType ) -> MockConfigEntry: diff --git a/tests/components/shelly/test_climate.py b/tests/components/shelly/test_climate.py index 241c6a00724..aac14c24288 100644 --- a/tests/components/shelly/test_climate.py +++ b/tests/components/shelly/test_climate.py @@ -492,7 +492,6 @@ async def test_block_set_mode_auth_error( {ATTR_ENTITY_ID: ENTITY_ID, ATTR_HVAC_MODE: HVACMode.HEAT}, blocking=True, ) - await hass.async_block_till_done() assert entry.state is ConfigEntryState.LOADED diff --git a/tests/components/shelly/test_number.py b/tests/components/shelly/test_number.py index 0b9fee9e47f..a5f64409d09 100644 --- a/tests/components/shelly/test_number.py +++ b/tests/components/shelly/test_number.py @@ -227,7 +227,6 @@ async def test_block_set_value_auth_error( {ATTR_ENTITY_ID: "number.test_name_valve_position", ATTR_VALUE: 30}, blocking=True, ) - await hass.async_block_till_done() assert entry.state is ConfigEntryState.LOADED diff --git a/tests/components/shelly/test_sensor.py b/tests/components/shelly/test_sensor.py index ceaa9b66b8d..e7bac38c7fd 100644 --- a/tests/components/shelly/test_sensor.py +++ b/tests/components/shelly/test_sensor.py @@ -618,7 +618,6 @@ async def test_rpc_sleeping_update_entity_service( service_data={ATTR_ENTITY_ID: entity_id}, blocking=True, ) - await hass.async_block_till_done() # Entity should be available after update_entity service call state = hass.states.get(entity_id) @@ -667,7 +666,6 @@ async def test_block_sleeping_update_entity_service( service_data={ATTR_ENTITY_ID: entity_id}, blocking=True, ) - await hass.async_block_till_done() # Entity should be available after update_entity service call state = hass.states.get(entity_id) diff --git a/tests/components/shelly/test_switch.py b/tests/components/shelly/test_switch.py index 3bcb262bee1..212fd4e6bab 100644 --- a/tests/components/shelly/test_switch.py +++ b/tests/components/shelly/test_switch.py @@ -230,7 +230,6 @@ async def test_block_set_state_auth_error( {ATTR_ENTITY_ID: "switch.test_name_channel_1"}, blocking=True, ) - await hass.async_block_till_done() assert entry.state is ConfigEntryState.LOADED @@ -374,7 +373,6 @@ async def test_rpc_auth_error( {ATTR_ENTITY_ID: "switch.test_switch_0"}, blocking=True, ) - await hass.async_block_till_done() assert entry.state is ConfigEntryState.LOADED diff --git a/tests/components/shelly/test_update.py b/tests/components/shelly/test_update.py index 0f26fd14d12..b4ec42762bb 100644 --- a/tests/components/shelly/test_update.py +++ b/tests/components/shelly/test_update.py @@ -207,7 +207,6 @@ async def test_block_update_auth_error( {ATTR_ENTITY_ID: "update.test_name_firmware_update"}, blocking=True, ) - await hass.async_block_till_done() assert entry.state is ConfigEntryState.LOADED @@ -669,7 +668,6 @@ async def test_rpc_update_auth_error( blocking=True, ) - await hass.async_block_till_done() assert entry.state is ConfigEntryState.LOADED flows = hass.config_entries.flow.async_progress() diff --git a/tests/components/switcher_kis/conftest.py b/tests/components/switcher_kis/conftest.py index 5f04df7dc66..eb3b92120e1 100644 --- a/tests/components/switcher_kis/conftest.py +++ b/tests/components/switcher_kis/conftest.py @@ -18,9 +18,15 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]: @pytest.fixture def mock_bridge(request): """Return a mocked SwitcherBridge.""" - with patch( - "homeassistant.components.switcher_kis.utils.SwitcherBridge", autospec=True - ) as bridge_mock: + with ( + patch( + "homeassistant.components.switcher_kis.SwitcherBridge", autospec=True + ) as bridge_mock, + patch( + "homeassistant.components.switcher_kis.utils.SwitcherBridge", + new=bridge_mock, + ), + ): bridge = bridge_mock.return_value bridge.devices = [] diff --git a/tests/components/switcher_kis/test_init.py b/tests/components/switcher_kis/test_init.py index 70eb518820c..14217a7e044 100644 --- a/tests/components/switcher_kis/test_init.py +++ b/tests/components/switcher_kis/test_init.py @@ -4,11 +4,7 @@ from datetime import timedelta import pytest -from homeassistant.components.switcher_kis.const import ( - DATA_DEVICE, - DOMAIN, - MAX_UPDATE_INTERVAL_SEC, -) +from homeassistant.components.switcher_kis.const import MAX_UPDATE_INTERVAL_SEC from homeassistant.config_entries import ConfigEntryState from homeassistant.const import STATE_UNAVAILABLE from homeassistant.core import HomeAssistant @@ -24,15 +20,14 @@ async def test_update_fail( hass: HomeAssistant, mock_bridge, caplog: pytest.LogCaptureFixture ) -> None: """Test entities state unavailable when updates fail..""" - await init_integration(hass) + entry = await init_integration(hass) assert mock_bridge mock_bridge.mock_callbacks(DUMMY_SWITCHER_DEVICES) await hass.async_block_till_done() assert mock_bridge.is_running is True - assert len(hass.data[DOMAIN]) == 2 - assert len(hass.data[DOMAIN][DATA_DEVICE]) == 2 + assert len(entry.runtime_data) == 2 async_fire_time_changed( hass, dt_util.utcnow() + timedelta(seconds=MAX_UPDATE_INTERVAL_SEC + 1) @@ -77,11 +72,9 @@ async def test_entry_unload(hass: HomeAssistant, mock_bridge) -> None: assert entry.state is ConfigEntryState.LOADED assert mock_bridge.is_running is True - assert len(hass.data[DOMAIN]) == 2 await hass.config_entries.async_unload(entry.entry_id) await hass.async_block_till_done() assert entry.state is ConfigEntryState.NOT_LOADED assert mock_bridge.is_running is False - assert len(hass.data[DOMAIN]) == 0 diff --git a/tests/components/switcher_kis/test_sensor.py b/tests/components/switcher_kis/test_sensor.py index f61cdd5a010..bfe1b2c84dd 100644 --- a/tests/components/switcher_kis/test_sensor.py +++ b/tests/components/switcher_kis/test_sensor.py @@ -2,7 +2,6 @@ import pytest -from homeassistant.components.switcher_kis.const import DATA_DEVICE, DOMAIN from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er from homeassistant.util import slugify @@ -32,12 +31,11 @@ DEVICE_SENSORS_TUPLE = ( @pytest.mark.parametrize("mock_bridge", [DUMMY_SWITCHER_DEVICES], indirect=True) async def test_sensor_platform(hass: HomeAssistant, mock_bridge) -> None: """Test sensor platform.""" - await init_integration(hass) + entry = await init_integration(hass) assert mock_bridge assert mock_bridge.is_running is True - assert len(hass.data[DOMAIN]) == 2 - assert len(hass.data[DOMAIN][DATA_DEVICE]) == 2 + assert len(entry.runtime_data) == 2 for device, sensors in DEVICE_SENSORS_TUPLE: for sensor, field in sensors: diff --git a/tests/components/teslemetry/snapshots/test_device_tracker.ambr b/tests/components/teslemetry/snapshots/test_device_tracker.ambr new file mode 100644 index 00000000000..369a3e3a2b9 --- /dev/null +++ b/tests/components/teslemetry/snapshots/test_device_tracker.ambr @@ -0,0 +1,101 @@ +# serializer version: 1 +# name: test_device_tracker[device_tracker.test_location-entry] + EntityRegistryEntrySnapshot({ + 'aliases': set({ + }), + 'area_id': None, + 'capabilities': None, + 'config_entry_id': , + 'device_class': None, + 'device_id': , + 'disabled_by': None, + 'domain': 'device_tracker', + 'entity_category': , + 'entity_id': 'device_tracker.test_location', + 'has_entity_name': True, + 'hidden_by': None, + 'icon': None, + 'id': , + 'labels': set({ + }), + 'name': None, + 'options': dict({ + }), + 'original_device_class': None, + 'original_icon': None, + 'original_name': 'Location', + 'platform': 'teslemetry', + 'previous_unique_id': None, + 'supported_features': 0, + 'translation_key': 'location', + 'unique_id': 'VINVINVIN-location', + 'unit_of_measurement': None, + }) +# --- +# name: test_device_tracker[device_tracker.test_location-state] + StateSnapshot({ + 'attributes': ReadOnlyDict({ + 'friendly_name': 'Test Location', + 'gps_accuracy': 0, + 'latitude': -30.222626, + 'longitude': -97.6236871, + 'source_type': , + }), + 'context': , + 'entity_id': 'device_tracker.test_location', + 'last_changed': , + 'last_reported': , + 'last_updated': , + 'state': 'not_home', + }) +# --- +# name: test_device_tracker[device_tracker.test_route-entry] + EntityRegistryEntrySnapshot({ + 'aliases': set({ + }), + 'area_id': None, + 'capabilities': None, + 'config_entry_id': , + 'device_class': None, + 'device_id': , + 'disabled_by': None, + 'domain': 'device_tracker', + 'entity_category': , + 'entity_id': 'device_tracker.test_route', + 'has_entity_name': True, + 'hidden_by': None, + 'icon': None, + 'id': , + 'labels': set({ + }), + 'name': None, + 'options': dict({ + }), + 'original_device_class': None, + 'original_icon': None, + 'original_name': 'Route', + 'platform': 'teslemetry', + 'previous_unique_id': None, + 'supported_features': 0, + 'translation_key': 'route', + 'unique_id': 'VINVINVIN-route', + 'unit_of_measurement': None, + }) +# --- +# name: test_device_tracker[device_tracker.test_route-state] + StateSnapshot({ + 'attributes': ReadOnlyDict({ + 'friendly_name': 'Test Route', + 'gps_accuracy': 0, + 'latitude': 30.2226265, + 'longitude': -97.6236871, + 'source_type': , + }), + 'context': , + 'entity_id': 'device_tracker.test_route', + 'last_changed': , + 'last_reported': , + 'last_updated': , + 'state': 'not_home', + }) +# --- diff --git a/tests/components/teslemetry/test_device_tracker.py b/tests/components/teslemetry/test_device_tracker.py new file mode 100644 index 00000000000..55deaefdab5 --- /dev/null +++ b/tests/components/teslemetry/test_device_tracker.py @@ -0,0 +1,33 @@ +"""Test the Teslemetry device tracker platform.""" + +from syrupy import SnapshotAssertion +from tesla_fleet_api.exceptions import VehicleOffline + +from homeassistant.const import STATE_UNKNOWN, Platform +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er + +from . import assert_entities, setup_platform + + +async def test_device_tracker( + hass: HomeAssistant, + snapshot: SnapshotAssertion, + entity_registry: er.EntityRegistry, +) -> None: + """Tests that the device tracker entities are correct.""" + + entry = await setup_platform(hass, [Platform.DEVICE_TRACKER]) + assert_entities(hass, entry.entry_id, entity_registry, snapshot) + + +async def test_device_tracker_offline( + hass: HomeAssistant, + mock_vehicle_data, +) -> None: + """Tests that the device tracker entities are correct when offline.""" + + mock_vehicle_data.side_effect = VehicleOffline + await setup_platform(hass, [Platform.DEVICE_TRACKER]) + state = hass.states.get("device_tracker.test_location") + assert state.state == STATE_UNKNOWN diff --git a/tests/components/webostv/test_device_trigger.py b/tests/components/webostv/test_device_trigger.py index 5205f6ae7a1..8d62d4e0b17 100644 --- a/tests/components/webostv/test_device_trigger.py +++ b/tests/components/webostv/test_device_trigger.py @@ -92,7 +92,6 @@ async def test_if_fires_on_turn_on_request(hass: HomeAssistant, calls, client) - blocking=True, ) - await hass.async_block_till_done() assert len(calls) == 2 assert calls[0].data["some"] == device.id assert calls[0].data["id"] == 0 diff --git a/tests/components/webostv/test_trigger.py b/tests/components/webostv/test_trigger.py index dd119bd0d5a..73c55df8807 100644 --- a/tests/components/webostv/test_trigger.py +++ b/tests/components/webostv/test_trigger.py @@ -56,7 +56,6 @@ async def test_webostv_turn_on_trigger_device_id( blocking=True, ) - await hass.async_block_till_done() assert len(calls) == 1 assert calls[0].data["some"] == device.id assert calls[0].data["id"] == 0 @@ -74,7 +73,6 @@ async def test_webostv_turn_on_trigger_device_id( blocking=True, ) - await hass.async_block_till_done() assert len(calls) == 0 @@ -113,7 +111,6 @@ async def test_webostv_turn_on_trigger_entity_id( blocking=True, ) - await hass.async_block_till_done() assert len(calls) == 1 assert calls[0].data["some"] == ENTITY_ID assert calls[0].data["id"] == 0 diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 70c28545483..e3308b89061 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -174,9 +174,9 @@ async def test_assist_api_prompt( ) api = llm.async_get_api(hass, "assist") prompt = await api.async_get_api_prompt(tool_input) - assert ( - prompt - == "Call the intent tools to control Home Assistant. Just pass the name to the intent." + assert prompt == ( + "Call the intent tools to control Home Assistant." + " Just pass the name to the intent." ) entry = MockConfigEntry(title=None) @@ -190,18 +190,18 @@ async def test_assist_api_prompt( suggested_area="Test Area", ).id prompt = await api.async_get_api_prompt(tool_input) - assert ( - prompt - == "Call the intent tools to control Home Assistant. Just pass the name to the intent. You are in Test Area." + assert prompt == ( + "Call the intent tools to control Home Assistant." + " Just pass the name to the intent. You are in Test Area." ) floor = floor_registry.async_create("second floor") area = area_registry.async_get_area_by_name("Test Area") area_registry.async_update(area.id, floor_id=floor.floor_id) prompt = await api.async_get_api_prompt(tool_input) - assert ( - prompt - == "Call the intent tools to control Home Assistant. Just pass the name to the intent. You are in Test Area (second floor)." + assert prompt == ( + "Call the intent tools to control Home Assistant." + " Just pass the name to the intent. You are in Test Area (second floor)." ) context.user_id = "12345" @@ -210,7 +210,8 @@ async def test_assist_api_prompt( mock_user.name = "Test User" with patch("homeassistant.auth.AuthManager.async_get_user", return_value=mock_user): prompt = await api.async_get_api_prompt(tool_input) - assert ( - prompt - == "Call the intent tools to control Home Assistant. Just pass the name to the intent. You are in Test Area (second floor). The user name is Test User." + assert prompt == ( + "Call the intent tools to control Home Assistant." + " Just pass the name to the intent. You are in Test Area (second floor)." + " The user name is Test User." )