diff --git a/homeassistant/components/telegram_bot/__init__.py b/homeassistant/components/telegram_bot/__init__.py index b9a032d7f28..4fdb87f9fa6 100644 --- a/homeassistant/components/telegram_bot/__init__.py +++ b/homeassistant/components/telegram_bot/__init__.py @@ -36,7 +36,13 @@ from homeassistant.const import ( HTTP_BEARER_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION, ) -from homeassistant.core import Context, HomeAssistant, ServiceCall +from homeassistant.core import ( + Context, + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, +) from homeassistant.helpers import config_validation as cv, issue_registry as ir from homeassistant.helpers.typing import ConfigType from homeassistant.loader import async_get_loaded_integration @@ -398,15 +404,18 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass, bot, p_config.get(CONF_ALLOWED_CHAT_IDS), p_config.get(ATTR_PARSER) ) - async def async_send_telegram_message(service: ServiceCall) -> None: + async def async_send_telegram_message(service: ServiceCall) -> ServiceResponse: """Handle sending Telegram Bot message service calls.""" msgtype = service.service kwargs = dict(service.data) _LOGGER.debug("New telegram message %s: %s", msgtype, kwargs) + messages = None if msgtype == SERVICE_SEND_MESSAGE: - await notify_service.send_message(context=service.context, **kwargs) + messages = await notify_service.send_message( + context=service.context, **kwargs + ) elif msgtype in [ SERVICE_SEND_PHOTO, SERVICE_SEND_ANIMATION, @@ -414,13 +423,19 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: SERVICE_SEND_VOICE, SERVICE_SEND_DOCUMENT, ]: - await notify_service.send_file(msgtype, context=service.context, **kwargs) + messages = await notify_service.send_file( + msgtype, context=service.context, **kwargs + ) elif msgtype == SERVICE_SEND_STICKER: - await notify_service.send_sticker(context=service.context, **kwargs) + messages = await notify_service.send_sticker( + context=service.context, **kwargs + ) elif msgtype == SERVICE_SEND_LOCATION: - await notify_service.send_location(context=service.context, **kwargs) + messages = await notify_service.send_location( + context=service.context, **kwargs + ) elif msgtype == SERVICE_SEND_POLL: - await notify_service.send_poll(context=service.context, **kwargs) + messages = await notify_service.send_poll(context=service.context, **kwargs) elif msgtype == SERVICE_ANSWER_CALLBACK_QUERY: await notify_service.answer_callback_query( context=service.context, **kwargs @@ -432,10 +447,37 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: msgtype, context=service.context, **kwargs ) + if service.return_response and messages: + return { + "chats": [ + {"chat_id": cid, "message_id": mid} for cid, mid in messages.items() + ] + } + return None + # Register notification services for service_notif, schema in SERVICE_MAP.items(): + supports_response = SupportsResponse.NONE + + if service_notif in [ + SERVICE_SEND_MESSAGE, + SERVICE_SEND_PHOTO, + SERVICE_SEND_ANIMATION, + SERVICE_SEND_VIDEO, + SERVICE_SEND_VOICE, + SERVICE_SEND_DOCUMENT, + SERVICE_SEND_STICKER, + SERVICE_SEND_LOCATION, + SERVICE_SEND_POLL, + ]: + supports_response = SupportsResponse.OPTIONAL + hass.services.async_register( - DOMAIN, service_notif, async_send_telegram_message, schema=schema + DOMAIN, + service_notif, + async_send_telegram_message, + schema=schema, + supports_response=supports_response, ) return True @@ -694,9 +736,10 @@ class TelegramNotificationService: title = kwargs.get(ATTR_TITLE) text = f"{title}\n{message}" if title else message params = self._get_msg_kwargs(kwargs) + msg_ids = {} for chat_id in self._get_target_chat_ids(target): _LOGGER.debug("Send message in chat ID %s with params: %s", chat_id, params) - await self._send_msg( + msg = await self._send_msg( self.bot.send_message, "Error sending message", params[ATTR_MESSAGE_TAG], @@ -711,6 +754,8 @@ class TelegramNotificationService: message_thread_id=params[ATTR_MESSAGE_THREAD_ID], context=context, ) + msg_ids[chat_id] = msg.id + return msg_ids async def delete_message(self, chat_id=None, context=None, **kwargs): """Delete a previously sent message.""" @@ -829,12 +874,13 @@ class TelegramNotificationService: ), ) + msg_ids = {} if file_content: for chat_id in self._get_target_chat_ids(target): _LOGGER.debug("Sending file to chat ID %s", chat_id) if file_type == SERVICE_SEND_PHOTO: - await self._send_msg( + msg = await self._send_msg( self.bot.send_photo, "Error sending photo", params[ATTR_MESSAGE_TAG], @@ -851,7 +897,7 @@ class TelegramNotificationService: ) elif file_type == SERVICE_SEND_STICKER: - await self._send_msg( + msg = await self._send_msg( self.bot.send_sticker, "Error sending sticker", params[ATTR_MESSAGE_TAG], @@ -866,7 +912,7 @@ class TelegramNotificationService: ) elif file_type == SERVICE_SEND_VIDEO: - await self._send_msg( + msg = await self._send_msg( self.bot.send_video, "Error sending video", params[ATTR_MESSAGE_TAG], @@ -882,7 +928,7 @@ class TelegramNotificationService: context=context, ) elif file_type == SERVICE_SEND_DOCUMENT: - await self._send_msg( + msg = await self._send_msg( self.bot.send_document, "Error sending document", params[ATTR_MESSAGE_TAG], @@ -898,7 +944,7 @@ class TelegramNotificationService: context=context, ) elif file_type == SERVICE_SEND_VOICE: - await self._send_msg( + msg = await self._send_msg( self.bot.send_voice, "Error sending voice", params[ATTR_MESSAGE_TAG], @@ -913,7 +959,7 @@ class TelegramNotificationService: context=context, ) elif file_type == SERVICE_SEND_ANIMATION: - await self._send_msg( + msg = await self._send_msg( self.bot.send_animation, "Error sending animation", params[ATTR_MESSAGE_TAG], @@ -929,17 +975,22 @@ class TelegramNotificationService: context=context, ) + msg_ids[chat_id] = msg.id file_content.seek(0) else: _LOGGER.error("Can't send file with kwargs: %s", kwargs) - async def send_sticker(self, target=None, context=None, **kwargs): + return msg_ids + + async def send_sticker(self, target=None, context=None, **kwargs) -> dict: """Send a sticker from a telegram sticker pack.""" params = self._get_msg_kwargs(kwargs) stickerid = kwargs.get(ATTR_STICKER_ID) + + msg_ids = {} if stickerid: for chat_id in self._get_target_chat_ids(target): - await self._send_msg( + msg = await self._send_msg( self.bot.send_sticker, "Error sending sticker", params[ATTR_MESSAGE_TAG], @@ -952,8 +1003,9 @@ class TelegramNotificationService: message_thread_id=params[ATTR_MESSAGE_THREAD_ID], context=context, ) - else: - await self.send_file(SERVICE_SEND_STICKER, target, **kwargs) + msg_ids[chat_id] = msg.id + return msg_ids + return await self.send_file(SERVICE_SEND_STICKER, target, **kwargs) async def send_location( self, latitude, longitude, target=None, context=None, **kwargs @@ -962,11 +1014,12 @@ class TelegramNotificationService: latitude = float(latitude) longitude = float(longitude) params = self._get_msg_kwargs(kwargs) + msg_ids = {} for chat_id in self._get_target_chat_ids(target): _LOGGER.debug( "Send location %s/%s to chat ID %s", latitude, longitude, chat_id ) - await self._send_msg( + msg = await self._send_msg( self.bot.send_location, "Error sending location", params[ATTR_MESSAGE_TAG], @@ -979,6 +1032,8 @@ class TelegramNotificationService: message_thread_id=params[ATTR_MESSAGE_THREAD_ID], context=context, ) + msg_ids[chat_id] = msg.id + return msg_ids async def send_poll( self, @@ -993,9 +1048,10 @@ class TelegramNotificationService: """Send a poll.""" params = self._get_msg_kwargs(kwargs) openperiod = kwargs.get(ATTR_OPEN_PERIOD) + msg_ids = {} for chat_id in self._get_target_chat_ids(target): _LOGGER.debug("Send poll '%s' to chat ID %s", question, chat_id) - await self._send_msg( + msg = await self._send_msg( self.bot.send_poll, "Error sending poll", params[ATTR_MESSAGE_TAG], @@ -1011,6 +1067,8 @@ class TelegramNotificationService: message_thread_id=params[ATTR_MESSAGE_THREAD_ID], context=context, ) + msg_ids[chat_id] = msg.id + return msg_ids async def leave_chat(self, chat_id=None, context=None): """Remove bot from chat.""" diff --git a/tests/components/telegram_bot/conftest.py b/tests/components/telegram_bot/conftest.py index 93137c3815e..f15db7eba2b 100644 --- a/tests/components/telegram_bot/conftest.py +++ b/tests/components/telegram_bot/conftest.py @@ -105,6 +105,14 @@ def mock_external_calls() -> Generator[None]: patch.object(BotMock, "get_me", return_value=test_user), patch.object(BotMock, "bot", test_user), patch.object(BotMock, "send_message", return_value=message), + patch.object(BotMock, "send_photo", return_value=message), + patch.object(BotMock, "send_sticker", return_value=message), + patch.object(BotMock, "send_video", return_value=message), + patch.object(BotMock, "send_document", return_value=message), + patch.object(BotMock, "send_voice", return_value=message), + patch.object(BotMock, "send_animation", return_value=message), + patch.object(BotMock, "send_location", return_value=message), + patch.object(BotMock, "send_poll", return_value=message), patch("telegram.ext.Updater._bootstrap"), ): yield diff --git a/tests/components/telegram_bot/test_telegram_bot.py b/tests/components/telegram_bot/test_telegram_bot.py index bdf6ba72fcc..be6b5b31325 100644 --- a/tests/components/telegram_bot/test_telegram_bot.py +++ b/tests/components/telegram_bot/test_telegram_bot.py @@ -1,17 +1,33 @@ """Tests for the telegram_bot component.""" +import base64 +import io from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, mock_open, patch import pytest from telegram import Update from telegram.error import NetworkError, RetryAfter, TelegramError, TimedOut from homeassistant.components.telegram_bot import ( + ATTR_FILE, + ATTR_LATITUDE, + ATTR_LONGITUDE, ATTR_MESSAGE, ATTR_MESSAGE_THREAD_ID, + ATTR_OPTIONS, + ATTR_QUESTION, + ATTR_STICKER_ID, DOMAIN, + SERVICE_SEND_ANIMATION, + SERVICE_SEND_DOCUMENT, + SERVICE_SEND_LOCATION, SERVICE_SEND_MESSAGE, + SERVICE_SEND_PHOTO, + SERVICE_SEND_POLL, + SERVICE_SEND_STICKER, + SERVICE_SEND_VIDEO, + SERVICE_SEND_VOICE, ) from homeassistant.components.telegram_bot.webhooks import TELEGRAM_WEBHOOK_URL from homeassistant.const import EVENT_HOMEASSISTANT_START @@ -32,23 +48,125 @@ async def test_polling_platform_init(hass: HomeAssistant, polling_platform) -> N assert hass.services.has_service(DOMAIN, SERVICE_SEND_MESSAGE) is True -async def test_send_message(hass: HomeAssistant, webhook_platform) -> None: - """Test the send_message service.""" +@pytest.mark.parametrize( + ("service", "input"), + [ + ( + SERVICE_SEND_MESSAGE, + {ATTR_MESSAGE: "test_message", ATTR_MESSAGE_THREAD_ID: "123"}, + ), + ( + SERVICE_SEND_STICKER, + { + ATTR_STICKER_ID: "1", + ATTR_MESSAGE_THREAD_ID: "123", + }, + ), + ( + SERVICE_SEND_POLL, + { + ATTR_QUESTION: "Question", + ATTR_OPTIONS: ["Yes", "No"], + }, + ), + ( + SERVICE_SEND_LOCATION, + { + ATTR_MESSAGE: "test_message", + ATTR_MESSAGE_THREAD_ID: "123", + ATTR_LONGITUDE: "1.123", + ATTR_LATITUDE: "1.123", + }, + ), + ], +) +async def test_send_message( + hass: HomeAssistant, webhook_platform, service: str, input: dict[str] +) -> None: + """Test the send_message service. Tests any service that does not require files to be sent.""" context = Context() events = async_capture_events(hass, "telegram_sent") - await hass.services.async_call( + response = await hass.services.async_call( DOMAIN, - SERVICE_SEND_MESSAGE, - {ATTR_MESSAGE: "test_message", ATTR_MESSAGE_THREAD_ID: "123"}, + service, + input, blocking=True, context=context, + return_response=True, ) await hass.async_block_till_done() assert len(events) == 1 assert events[0].context == context + assert len(response["chats"]) == 1 + assert (response["chats"][0]["message_id"]) == 12345 + + +@patch( + "builtins.open", + mock_open( + read_data=base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=" + ) + ), + create=True, +) +def _read_file_as_bytesio_mock(file_path): + """Convert file to BytesIO for testing.""" + _file = None + + with open(file_path, encoding="utf8") as file_handler: + _file = io.BytesIO(file_handler.read()) + + _file.name = "dummy" + _file.seek(0) + + return _file + + +@pytest.mark.parametrize( + "service", + [ + SERVICE_SEND_PHOTO, + SERVICE_SEND_ANIMATION, + SERVICE_SEND_VIDEO, + SERVICE_SEND_VOICE, + SERVICE_SEND_DOCUMENT, + ], +) +async def test_send_file(hass: HomeAssistant, webhook_platform, service: str) -> None: + """Test the send_file service (photo, animation, video, document...).""" + context = Context() + events = async_capture_events(hass, "telegram_sent") + + hass.config.allowlist_external_dirs.add("/media/") + + # Mock the file handler read with our base64 encoded dummy file + with patch( + "homeassistant.components.telegram_bot._read_file_as_bytesio", + _read_file_as_bytesio_mock, + ): + response = await hass.services.async_call( + DOMAIN, + service, + { + ATTR_FILE: "/media/dummy", + ATTR_MESSAGE_THREAD_ID: "123", + }, + blocking=True, + context=context, + return_response=True, + ) + await hass.async_block_till_done() + + assert len(events) == 1 + assert events[0].context == context + + assert len(response["chats"]) == 1 + assert (response["chats"][0]["message_id"]) == 12345 + async def test_send_message_thread(hass: HomeAssistant, webhook_platform) -> None: """Test the send_message service for threads."""