Add strict typing for Telegram bot integration (#147262)

add strict typing
This commit is contained in:
hanwg
2025-06-24 04:22:00 +08:00
committed by GitHub
parent 8b6205be25
commit dc948e3b6c
5 changed files with 128 additions and 60 deletions

View File

@ -503,6 +503,7 @@ homeassistant.components.tautulli.*
homeassistant.components.tcp.*
homeassistant.components.technove.*
homeassistant.components.tedee.*
homeassistant.components.telegram_bot.*
homeassistant.components.text.*
homeassistant.components.thethingsnetwork.*
homeassistant.components.threshold.*

View File

@ -2,8 +2,10 @@
from abc import abstractmethod
import asyncio
from collections.abc import Callable, Sequence
import io
import logging
from ssl import SSLContext
from types import MappingProxyType
from typing import Any
@ -13,6 +15,7 @@ from telegram import (
CallbackQuery,
InlineKeyboardButton,
InlineKeyboardMarkup,
InputPollOption,
Message,
ReplyKeyboardMarkup,
ReplyKeyboardRemove,
@ -262,7 +265,9 @@ class TelegramNotificationService:
return allowed_chat_ids
def _get_msg_ids(self, msg_data, chat_id):
def _get_msg_ids(
self, msg_data: dict[str, Any], chat_id: int
) -> tuple[Any | None, int | None]:
"""Get the message id to edit.
This can be one of (message_id, inline_message_id) from a msg dict,
@ -270,7 +275,8 @@ class TelegramNotificationService:
**You can use 'last' as message_id** to edit
the message last sent in the chat_id.
"""
message_id = inline_message_id = None
message_id: Any | None = None
inline_message_id: int | None = None
if ATTR_MESSAGEID in msg_data:
message_id = msg_data[ATTR_MESSAGEID]
if (
@ -283,7 +289,7 @@ class TelegramNotificationService:
inline_message_id = msg_data["inline_message_id"]
return message_id, inline_message_id
def _get_target_chat_ids(self, target):
def _get_target_chat_ids(self, target: Any) -> list[int]:
"""Validate chat_id targets or return default target (first).
:param target: optional list of integers ([12234, -12345])
@ -302,10 +308,10 @@ class TelegramNotificationService:
)
return [default_user]
def _get_msg_kwargs(self, data):
def _get_msg_kwargs(self, data: dict[str, Any]) -> dict[str, Any]:
"""Get parameters in message data kwargs."""
def _make_row_inline_keyboard(row_keyboard):
def _make_row_inline_keyboard(row_keyboard: Any) -> list[InlineKeyboardButton]:
"""Make a list of InlineKeyboardButtons.
It can accept:
@ -350,7 +356,7 @@ class TelegramNotificationService:
return buttons
# Defaults
params = {
params: dict[str, Any] = {
ATTR_PARSER: self.parse_mode,
ATTR_DISABLE_NOTIF: False,
ATTR_DISABLE_WEB_PREV: None,
@ -399,8 +405,14 @@ class TelegramNotificationService:
return params
async def _send_msg(
self, func_send, msg_error, message_tag, *args_msg, context=None, **kwargs_msg
):
self,
func_send: Callable,
msg_error: str,
message_tag: str | None,
*args_msg: Any,
context: Context | None = None,
**kwargs_msg: Any,
) -> Any:
"""Send one message."""
try:
out = await func_send(*args_msg, **kwargs_msg)
@ -438,7 +450,13 @@ class TelegramNotificationService:
return None
return out
async def send_message(self, message="", target=None, context=None, **kwargs):
async def send_message(
self,
message: str = "",
target: Any = None,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> dict[int, int]:
"""Send a message to one or multiple pre-allowed chat IDs."""
title = kwargs.get(ATTR_TITLE)
text = f"{title}\n{message}" if title else message
@ -465,12 +483,17 @@ class TelegramNotificationService:
msg_ids[chat_id] = msg.id
return msg_ids
async def delete_message(self, chat_id=None, context=None, **kwargs):
async def delete_message(
self,
chat_id: int | None = None,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> bool:
"""Delete a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0]
message_id, _ = self._get_msg_ids(kwargs, chat_id)
_LOGGER.debug("Delete message %s in chat ID %s", message_id, chat_id)
deleted = await self._send_msg(
deleted: bool = await self._send_msg(
self.bot.delete_message,
"Error deleting message",
None,
@ -484,7 +507,13 @@ class TelegramNotificationService:
self._last_message_id[chat_id] -= 1
return deleted
async def edit_message(self, type_edit, chat_id=None, context=None, **kwargs):
async def edit_message(
self,
type_edit: str,
chat_id: int | None = None,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> Any:
"""Edit a previously sent message."""
chat_id = self._get_target_chat_ids(chat_id)[0]
message_id, inline_message_id = self._get_msg_ids(kwargs, chat_id)
@ -542,8 +571,13 @@ class TelegramNotificationService:
)
async def answer_callback_query(
self, message, callback_query_id, show_alert=False, context=None, **kwargs
):
self,
message: str | None,
callback_query_id: str,
show_alert: bool = False,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> None:
"""Answer a callback originated with a press in an inline keyboard."""
params = self._get_msg_kwargs(kwargs)
_LOGGER.debug(
@ -564,16 +598,20 @@ class TelegramNotificationService:
)
async def send_file(
self, file_type=SERVICE_SEND_PHOTO, target=None, context=None, **kwargs
):
self,
file_type: str,
target: Any = None,
context: Context | None = None,
**kwargs: Any,
) -> dict[int, int]:
"""Send a photo, sticker, video, or document."""
params = self._get_msg_kwargs(kwargs)
file_content = await load_data(
self.hass,
url=kwargs.get(ATTR_URL),
filepath=kwargs.get(ATTR_FILE),
username=kwargs.get(ATTR_USERNAME),
password=kwargs.get(ATTR_PASSWORD),
username=kwargs.get(ATTR_USERNAME, ""),
password=kwargs.get(ATTR_PASSWORD, ""),
authentication=kwargs.get(ATTR_AUTHENTICATION),
verify_ssl=(
get_default_context()
@ -690,7 +728,12 @@ class TelegramNotificationService:
return msg_ids
async def send_sticker(self, target=None, context=None, **kwargs) -> dict:
async def send_sticker(
self,
target: Any = None,
context: Context | None = None,
**kwargs: Any,
) -> dict[int, int]:
"""Send a sticker from a telegram sticker pack."""
params = self._get_msg_kwargs(kwargs)
stickerid = kwargs.get(ATTR_STICKER_ID)
@ -713,11 +756,16 @@ class TelegramNotificationService:
)
msg_ids[chat_id] = msg.id
return msg_ids
return await self.send_file(SERVICE_SEND_STICKER, target, **kwargs)
return await self.send_file(SERVICE_SEND_STICKER, target, context, **kwargs)
async def send_location(
self, latitude, longitude, target=None, context=None, **kwargs
):
self,
latitude: Any,
longitude: Any,
target: Any = None,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> dict[int, int]:
"""Send a location."""
latitude = float(latitude)
longitude = float(longitude)
@ -745,14 +793,14 @@ class TelegramNotificationService:
async def send_poll(
self,
question,
options,
is_anonymous,
allows_multiple_answers,
target=None,
context=None,
**kwargs,
):
question: str,
options: Sequence[str | InputPollOption],
is_anonymous: bool | None,
allows_multiple_answers: bool | None,
target: Any = None,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> dict[int, int]:
"""Send a poll."""
params = self._get_msg_kwargs(kwargs)
openperiod = kwargs.get(ATTR_OPEN_PERIOD)
@ -778,7 +826,12 @@ class TelegramNotificationService:
msg_ids[chat_id] = msg.id
return msg_ids
async def leave_chat(self, chat_id=None, context=None, **kwargs):
async def leave_chat(
self,
chat_id: Any = None,
context: Context | None = None,
**kwargs: dict[str, Any],
) -> Any:
"""Remove bot from chat."""
chat_id = self._get_target_chat_ids(chat_id)[0]
_LOGGER.debug("Leave from chat ID %s", chat_id)
@ -792,7 +845,7 @@ class TelegramNotificationService:
reaction: str,
is_big: bool = False,
context: Context | None = None,
**kwargs,
**kwargs: dict[str, Any],
) -> None:
"""Set the bot's reaction for a given message."""
chat_id = self._get_target_chat_ids(chat_id)[0]
@ -878,19 +931,19 @@ def initialize_bot(hass: HomeAssistant, p_config: MappingProxyType[str, Any]) ->
async def load_data(
hass: HomeAssistant,
url=None,
filepath=None,
username=None,
password=None,
authentication=None,
num_retries=5,
verify_ssl=None,
):
url: str | None,
filepath: str | None,
username: str,
password: str,
authentication: str | None,
verify_ssl: SSLContext,
num_retries: int = 5,
) -> io.BytesIO:
"""Load data into ByteIO/File container from a source."""
if url is not None:
# Load data from URL
params: dict[str, Any] = {}
headers = {}
headers: dict[str, str] = {}
_validate_credentials_input(authentication, username, password)
if authentication == HTTP_BEARER_AUTHENTICATION:
headers = {"Authorization": f"Bearer {password}"}
@ -963,7 +1016,7 @@ def _validate_credentials_input(
) -> None:
if (
authentication in (HTTP_BASIC_AUTHENTICATION, HTTP_DIGEST_AUTHENTICATION)
and username is None
and not username
):
raise ServiceValidationError(
"Username is required.",
@ -979,7 +1032,7 @@ def _validate_credentials_input(
HTTP_BEARER_AUTHENTICATION,
HTTP_BEARER_AUTHENTICATION,
)
and password is None
and not password
):
raise ServiceValidationError(
"Password is required.",

View File

@ -64,16 +64,18 @@ class PollBot(BaseTelegramBot):
"""Shutdown the app."""
await self.stop_polling()
async def start_polling(self, event=None):
async def start_polling(self) -> None:
"""Start the polling task."""
_LOGGER.debug("Starting polling")
await self.application.initialize()
await self.application.updater.start_polling(error_callback=error_callback)
if self.application.updater:
await self.application.updater.start_polling(error_callback=error_callback)
await self.application.start()
async def stop_polling(self, event=None):
async def stop_polling(self) -> None:
"""Stop the polling task."""
_LOGGER.debug("Stopping polling")
await self.application.updater.stop()
if self.application.updater:
await self.application.updater.stop()
await self.application.stop()
await self.application.shutdown()

View File

@ -6,11 +6,12 @@ import logging
import secrets
import string
from aiohttp.web_response import Response
from telegram import Bot, Update
from telegram.error import NetworkError, TelegramError
from telegram.ext import ApplicationBuilder, TypeHandler
from telegram.ext import Application, ApplicationBuilder, TypeHandler
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http import HomeAssistantRequest, HomeAssistantView
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
@ -87,7 +88,7 @@ class PushBot(BaseTelegramBot):
"""Shutdown the app."""
await self.stop_application()
async def _try_to_set_webhook(self):
async def _try_to_set_webhook(self) -> bool:
_LOGGER.debug("Registering webhook URL: %s", self.webhook_url)
retry_num = 0
while retry_num < 3:
@ -103,12 +104,12 @@ class PushBot(BaseTelegramBot):
return False
async def start_application(self):
async def start_application(self) -> None:
"""Handle starting the Application object."""
await self.application.initialize()
await self.application.start()
async def register_webhook(self):
async def register_webhook(self) -> bool:
"""Query telegram and register the URL for our webhook."""
current_status = await self.bot.get_webhook_info()
# Some logging of Bot current status:
@ -123,13 +124,13 @@ class PushBot(BaseTelegramBot):
return True
async def stop_application(self, event=None):
async def stop_application(self) -> None:
"""Handle gracefully stopping the Application object."""
await self.deregister_webhook()
await self.application.stop()
await self.application.shutdown()
async def deregister_webhook(self):
async def deregister_webhook(self) -> None:
"""Query telegram and deregister the URL for our webhook."""
_LOGGER.debug("Deregistering webhook URL")
try:
@ -149,7 +150,7 @@ class PushBotView(HomeAssistantView):
self,
hass: HomeAssistant,
bot: Bot,
application,
application: Application,
trusted_networks: list[IPv4Network],
secret_token: str,
) -> None:
@ -160,15 +161,16 @@ class PushBotView(HomeAssistantView):
self.trusted_networks = trusted_networks
self.secret_token = secret_token
async def post(self, request):
async def post(self, request: HomeAssistantRequest) -> Response | None:
"""Accept the POST from telegram."""
real_ip = ip_address(request.remote)
if not any(real_ip in net for net in self.trusted_networks):
_LOGGER.warning("Access denied from %s", real_ip)
if not request.remote or not any(
ip_address(request.remote) in net for net in self.trusted_networks
):
_LOGGER.warning("Access denied from %s", request.remote)
return self.json_message("Access denied", HTTPStatus.UNAUTHORIZED)
secret_token_header = request.headers.get("X-Telegram-Bot-Api-Secret-Token")
if secret_token_header is None or self.secret_token != secret_token_header:
_LOGGER.warning("Invalid secret token from %s", real_ip)
_LOGGER.warning("Invalid secret token from %s", request.remote)
return self.json_message("Access denied", HTTPStatus.UNAUTHORIZED)
try:

10
mypy.ini generated
View File

@ -4788,6 +4788,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.telegram_bot.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.text.*]
check_untyped_defs = true
disallow_incomplete_defs = true