mirror of
https://github.com/home-assistant/core.git
synced 2025-06-24 09:01:55 +02:00
Add strict typing for Telegram bot integration (#147262)
add strict typing
This commit is contained in:
@ -503,6 +503,7 @@ homeassistant.components.tautulli.*
|
||||
homeassistant.components.tcp.*
|
||||
homeassistant.components.technove.*
|
||||
homeassistant.components.tedee.*
|
||||
homeassistant.components.telegram_bot.*
|
||||
homeassistant.components.text.*
|
||||
homeassistant.components.thethingsnetwork.*
|
||||
homeassistant.components.threshold.*
|
||||
|
@ -2,8 +2,10 @@
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
import io
|
||||
import logging
|
||||
from ssl import SSLContext
|
||||
from types import MappingProxyType
|
||||
from typing import Any
|
||||
|
||||
@ -13,6 +15,7 @@ from telegram import (
|
||||
CallbackQuery,
|
||||
InlineKeyboardButton,
|
||||
InlineKeyboardMarkup,
|
||||
InputPollOption,
|
||||
Message,
|
||||
ReplyKeyboardMarkup,
|
||||
ReplyKeyboardRemove,
|
||||
@ -262,7 +265,9 @@ class TelegramNotificationService:
|
||||
|
||||
return allowed_chat_ids
|
||||
|
||||
def _get_msg_ids(self, msg_data, chat_id):
|
||||
def _get_msg_ids(
|
||||
self, msg_data: dict[str, Any], chat_id: int
|
||||
) -> tuple[Any | None, int | None]:
|
||||
"""Get the message id to edit.
|
||||
|
||||
This can be one of (message_id, inline_message_id) from a msg dict,
|
||||
@ -270,7 +275,8 @@ class TelegramNotificationService:
|
||||
**You can use 'last' as message_id** to edit
|
||||
the message last sent in the chat_id.
|
||||
"""
|
||||
message_id = inline_message_id = None
|
||||
message_id: Any | None = None
|
||||
inline_message_id: int | None = None
|
||||
if ATTR_MESSAGEID in msg_data:
|
||||
message_id = msg_data[ATTR_MESSAGEID]
|
||||
if (
|
||||
@ -283,7 +289,7 @@ class TelegramNotificationService:
|
||||
inline_message_id = msg_data["inline_message_id"]
|
||||
return message_id, inline_message_id
|
||||
|
||||
def _get_target_chat_ids(self, target):
|
||||
def _get_target_chat_ids(self, target: Any) -> list[int]:
|
||||
"""Validate chat_id targets or return default target (first).
|
||||
|
||||
:param target: optional list of integers ([12234, -12345])
|
||||
@ -302,10 +308,10 @@ class TelegramNotificationService:
|
||||
)
|
||||
return [default_user]
|
||||
|
||||
def _get_msg_kwargs(self, data):
|
||||
def _get_msg_kwargs(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Get parameters in message data kwargs."""
|
||||
|
||||
def _make_row_inline_keyboard(row_keyboard):
|
||||
def _make_row_inline_keyboard(row_keyboard: Any) -> list[InlineKeyboardButton]:
|
||||
"""Make a list of InlineKeyboardButtons.
|
||||
|
||||
It can accept:
|
||||
@ -350,7 +356,7 @@ class TelegramNotificationService:
|
||||
return buttons
|
||||
|
||||
# Defaults
|
||||
params = {
|
||||
params: dict[str, Any] = {
|
||||
ATTR_PARSER: self.parse_mode,
|
||||
ATTR_DISABLE_NOTIF: False,
|
||||
ATTR_DISABLE_WEB_PREV: None,
|
||||
@ -399,8 +405,14 @@ class TelegramNotificationService:
|
||||
return params
|
||||
|
||||
async def _send_msg(
|
||||
self, func_send, msg_error, message_tag, *args_msg, context=None, **kwargs_msg
|
||||
):
|
||||
self,
|
||||
func_send: Callable,
|
||||
msg_error: str,
|
||||
message_tag: str | None,
|
||||
*args_msg: Any,
|
||||
context: Context | None = None,
|
||||
**kwargs_msg: Any,
|
||||
) -> Any:
|
||||
"""Send one message."""
|
||||
try:
|
||||
out = await func_send(*args_msg, **kwargs_msg)
|
||||
@ -438,7 +450,13 @@ class TelegramNotificationService:
|
||||
return None
|
||||
return out
|
||||
|
||||
async def send_message(self, message="", target=None, context=None, **kwargs):
|
||||
async def send_message(
|
||||
self,
|
||||
message: str = "",
|
||||
target: Any = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> dict[int, int]:
|
||||
"""Send a message to one or multiple pre-allowed chat IDs."""
|
||||
title = kwargs.get(ATTR_TITLE)
|
||||
text = f"{title}\n{message}" if title else message
|
||||
@ -465,12 +483,17 @@ class TelegramNotificationService:
|
||||
msg_ids[chat_id] = msg.id
|
||||
return msg_ids
|
||||
|
||||
async def delete_message(self, chat_id=None, context=None, **kwargs):
|
||||
async def delete_message(
|
||||
self,
|
||||
chat_id: int | None = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> bool:
|
||||
"""Delete a previously sent message."""
|
||||
chat_id = self._get_target_chat_ids(chat_id)[0]
|
||||
message_id, _ = self._get_msg_ids(kwargs, chat_id)
|
||||
_LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id)
|
||||
deleted = await self._send_msg(
|
||||
deleted: bool = await self._send_msg(
|
||||
self.bot.delete_message,
|
||||
"Error deleting message",
|
||||
None,
|
||||
@ -484,7 +507,13 @@ class TelegramNotificationService:
|
||||
self._last_message_id[chat_id] -= 1
|
||||
return deleted
|
||||
|
||||
async def edit_message(self, type_edit, chat_id=None, context=None, **kwargs):
|
||||
async def edit_message(
|
||||
self,
|
||||
type_edit: str,
|
||||
chat_id: int | None = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Edit a previously sent message."""
|
||||
chat_id = self._get_target_chat_ids(chat_id)[0]
|
||||
message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id)
|
||||
@ -542,8 +571,13 @@ class TelegramNotificationService:
|
||||
)
|
||||
|
||||
async def answer_callback_query(
|
||||
self, message, callback_query_id, show_alert=False, context=None, **kwargs
|
||||
):
|
||||
self,
|
||||
message: str | None,
|
||||
callback_query_id: str,
|
||||
show_alert: bool = False,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Answer a callback originated with a press in an inline keyboard."""
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
_LOGGER.debug(
|
||||
@ -564,16 +598,20 @@ class TelegramNotificationService:
|
||||
)
|
||||
|
||||
async def send_file(
|
||||
self, file_type=SERVICE_SEND_PHOTO, target=None, context=None, **kwargs
|
||||
):
|
||||
self,
|
||||
file_type: str,
|
||||
target: Any = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[int, int]:
|
||||
"""Send a photo, sticker, video, or document."""
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
file_content = await load_data(
|
||||
self.hass,
|
||||
url=kwargs.get(ATTR_URL),
|
||||
filepath=kwargs.get(ATTR_FILE),
|
||||
username=kwargs.get(ATTR_USERNAME),
|
||||
password=kwargs.get(ATTR_PASSWORD),
|
||||
username=kwargs.get(ATTR_USERNAME, ""),
|
||||
password=kwargs.get(ATTR_PASSWORD, ""),
|
||||
authentication=kwargs.get(ATTR_AUTHENTICATION),
|
||||
verify_ssl=(
|
||||
get_default_context()
|
||||
@ -690,7 +728,12 @@ class TelegramNotificationService:
|
||||
|
||||
return msg_ids
|
||||
|
||||
async def send_sticker(self, target=None, context=None, **kwargs) -> dict:
|
||||
async def send_sticker(
|
||||
self,
|
||||
target: Any = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: Any,
|
||||
) -> dict[int, int]:
|
||||
"""Send a sticker from a telegram sticker pack."""
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
stickerid = kwargs.get(ATTR_STICKER_ID)
|
||||
@ -713,11 +756,16 @@ class TelegramNotificationService:
|
||||
)
|
||||
msg_ids[chat_id] = msg.id
|
||||
return msg_ids
|
||||
return await self.send_file(SERVICE_SEND_STICKER, target, **kwargs)
|
||||
return await self.send_file(SERVICE_SEND_STICKER, target, context, **kwargs)
|
||||
|
||||
async def send_location(
|
||||
self, latitude, longitude, target=None, context=None, **kwargs
|
||||
):
|
||||
self,
|
||||
latitude: Any,
|
||||
longitude: Any,
|
||||
target: Any = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> dict[int, int]:
|
||||
"""Send a location."""
|
||||
latitude = float(latitude)
|
||||
longitude = float(longitude)
|
||||
@ -745,14 +793,14 @@ class TelegramNotificationService:
|
||||
|
||||
async def send_poll(
|
||||
self,
|
||||
question,
|
||||
options,
|
||||
is_anonymous,
|
||||
allows_multiple_answers,
|
||||
target=None,
|
||||
context=None,
|
||||
**kwargs,
|
||||
):
|
||||
question: str,
|
||||
options: Sequence[str | InputPollOption],
|
||||
is_anonymous: bool | None,
|
||||
allows_multiple_answers: bool | None,
|
||||
target: Any = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> dict[int, int]:
|
||||
"""Send a poll."""
|
||||
params = self._get_msg_kwargs(kwargs)
|
||||
openperiod = kwargs.get(ATTR_OPEN_PERIOD)
|
||||
@ -778,7 +826,12 @@ class TelegramNotificationService:
|
||||
msg_ids[chat_id] = msg.id
|
||||
return msg_ids
|
||||
|
||||
async def leave_chat(self, chat_id=None, context=None, **kwargs):
|
||||
async def leave_chat(
|
||||
self,
|
||||
chat_id: Any = None,
|
||||
context: Context | None = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Remove bot from chat."""
|
||||
chat_id = self._get_target_chat_ids(chat_id)[0]
|
||||
_LOGGER.debug("Leave from chat ID %s", chat_id)
|
||||
@ -792,7 +845,7 @@ class TelegramNotificationService:
|
||||
reaction: str,
|
||||
is_big: bool = False,
|
||||
context: Context | None = None,
|
||||
**kwargs,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Set the bot's reaction for a given message."""
|
||||
chat_id = self._get_target_chat_ids(chat_id)[0]
|
||||
@ -878,19 +931,19 @@ def initialize_bot(hass: HomeAssistant, p_config: MappingProxyType[str, Any]) ->
|
||||
|
||||
async def load_data(
|
||||
hass: HomeAssistant,
|
||||
url=None,
|
||||
filepath=None,
|
||||
username=None,
|
||||
password=None,
|
||||
authentication=None,
|
||||
num_retries=5,
|
||||
verify_ssl=None,
|
||||
):
|
||||
url: str | None,
|
||||
filepath: str | None,
|
||||
username: str,
|
||||
password: str,
|
||||
authentication: str | None,
|
||||
verify_ssl: SSLContext,
|
||||
num_retries: int = 5,
|
||||
) -> io.BytesIO:
|
||||
"""Load data into ByteIO/File container from a source."""
|
||||
if url is not None:
|
||||
# Load data from URL
|
||||
params: dict[str, Any] = {}
|
||||
headers = {}
|
||||
headers: dict[str, str] = {}
|
||||
_validate_credentials_input(authentication, username, password)
|
||||
if authentication == HTTP_BEARER_AUTHENTICATION:
|
||||
headers = {"Authorization": f"Bearer {password}"}
|
||||
@ -963,7 +1016,7 @@ def _validate_credentials_input(
|
||||
) -> None:
|
||||
if (
|
||||
authentication in (HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION)
|
||||
and username is None
|
||||
and not username
|
||||
):
|
||||
raise ServiceValidationError(
|
||||
"Username is required.",
|
||||
@ -979,7 +1032,7 @@ def _validate_credentials_input(
|
||||
HTTP_BEARER_AUTHENTICATION,
|
||||
HTTP_BEARER_AUTHENTICATION,
|
||||
)
|
||||
and password is None
|
||||
and not password
|
||||
):
|
||||
raise ServiceValidationError(
|
||||
"Password is required.",
|
||||
|
@ -64,16 +64,18 @@ class PollBot(BaseTelegramBot):
|
||||
"""Shutdown the app."""
|
||||
await self.stop_polling()
|
||||
|
||||
async def start_polling(self, event=None):
|
||||
async def start_polling(self) -> None:
|
||||
"""Start the polling task."""
|
||||
_LOGGER.debug("Starting polling")
|
||||
await self.application.initialize()
|
||||
await self.application.updater.start_polling(error_callback=error_callback)
|
||||
if self.application.updater:
|
||||
await self.application.updater.start_polling(error_callback=error_callback)
|
||||
await self.application.start()
|
||||
|
||||
async def stop_polling(self, event=None):
|
||||
async def stop_polling(self) -> None:
|
||||
"""Stop the polling task."""
|
||||
_LOGGER.debug("Stopping polling")
|
||||
await self.application.updater.stop()
|
||||
if self.application.updater:
|
||||
await self.application.updater.stop()
|
||||
await self.application.stop()
|
||||
await self.application.shutdown()
|
||||
|
@ -6,11 +6,12 @@ import logging
|
||||
import secrets
|
||||
import string
|
||||
|
||||
from aiohttp.web_response import Response
|
||||
from telegram import Bot, Update
|
||||
from telegram.error import NetworkError, TelegramError
|
||||
from telegram.ext import ApplicationBuilder, TypeHandler
|
||||
from telegram.ext import Application, ApplicationBuilder, TypeHandler
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http import HomeAssistantRequest, HomeAssistantView
|
||||
from homeassistant.const import CONF_URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
@ -87,7 +88,7 @@ class PushBot(BaseTelegramBot):
|
||||
"""Shutdown the app."""
|
||||
await self.stop_application()
|
||||
|
||||
async def _try_to_set_webhook(self):
|
||||
async def _try_to_set_webhook(self) -> bool:
|
||||
_LOGGER.debug("Registering webhook URL: %s", self.webhook_url)
|
||||
retry_num = 0
|
||||
while retry_num < 3:
|
||||
@ -103,12 +104,12 @@ class PushBot(BaseTelegramBot):
|
||||
|
||||
return False
|
||||
|
||||
async def start_application(self):
|
||||
async def start_application(self) -> None:
|
||||
"""Handle starting the Application object."""
|
||||
await self.application.initialize()
|
||||
await self.application.start()
|
||||
|
||||
async def register_webhook(self):
|
||||
async def register_webhook(self) -> bool:
|
||||
"""Query telegram and register the URL for our webhook."""
|
||||
current_status = await self.bot.get_webhook_info()
|
||||
# Some logging of Bot current status:
|
||||
@ -123,13 +124,13 @@ class PushBot(BaseTelegramBot):
|
||||
|
||||
return True
|
||||
|
||||
async def stop_application(self, event=None):
|
||||
async def stop_application(self) -> None:
|
||||
"""Handle gracefully stopping the Application object."""
|
||||
await self.deregister_webhook()
|
||||
await self.application.stop()
|
||||
await self.application.shutdown()
|
||||
|
||||
async def deregister_webhook(self):
|
||||
async def deregister_webhook(self) -> None:
|
||||
"""Query telegram and deregister the URL for our webhook."""
|
||||
_LOGGER.debug("Deregistering webhook URL")
|
||||
try:
|
||||
@ -149,7 +150,7 @@ class PushBotView(HomeAssistantView):
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
bot: Bot,
|
||||
application,
|
||||
application: Application,
|
||||
trusted_networks: list[IPv4Network],
|
||||
secret_token: str,
|
||||
) -> None:
|
||||
@ -160,15 +161,16 @@ class PushBotView(HomeAssistantView):
|
||||
self.trusted_networks = trusted_networks
|
||||
self.secret_token = secret_token
|
||||
|
||||
async def post(self, request):
|
||||
async def post(self, request: HomeAssistantRequest) -> Response | None:
|
||||
"""Accept the POST from telegram."""
|
||||
real_ip = ip_address(request.remote)
|
||||
if not any(real_ip in net for net in self.trusted_networks):
|
||||
_LOGGER.warning("Access denied from %s", real_ip)
|
||||
if not request.remote or not any(
|
||||
ip_address(request.remote) in net for net in self.trusted_networks
|
||||
):
|
||||
_LOGGER.warning("Access denied from %s", request.remote)
|
||||
return self.json_message("Access denied", HTTPStatus.UNAUTHORIZED)
|
||||
secret_token_header = request.headers.get("X-Telegram-Bot-Api-Secret-Token")
|
||||
if secret_token_header is None or self.secret_token != secret_token_header:
|
||||
_LOGGER.warning("Invalid secret token from %s", real_ip)
|
||||
_LOGGER.warning("Invalid secret token from %s", request.remote)
|
||||
return self.json_message("Access denied", HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
try:
|
||||
|
10
mypy.ini
generated
10
mypy.ini
generated
@ -4788,6 +4788,16 @@ disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.telegram_bot.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_subclassing_any = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_untyped_defs = true
|
||||
warn_return_any = true
|
||||
warn_unreachable = true
|
||||
|
||||
[mypy-homeassistant.components.text.*]
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
|
Reference in New Issue
Block a user