Files

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

372 lines
11 KiB
Python
Raw Permalink Normal View History

2019-04-03 17:40:03 +02:00
"""Slack platform for notify component."""
import asyncio
2015-08-01 06:45:41 +10:00
import logging
import os
2025-01-19 15:09:04 -05:00
from typing import Any, TypedDict, cast
from urllib.parse import urlparse
2015-08-01 06:45:41 +10:00
2025-01-19 15:09:04 -05:00
from aiohttp import BasicAuth
from aiohttp.client_exceptions import ClientError
2025-04-11 20:27:45 +02:00
from slack_sdk.errors import SlackApiError
2025-01-19 15:09:04 -05:00
from slack_sdk.web.async_client import AsyncWebClient
import voluptuous as vol
2019-03-27 20:36:13 -07:00
from homeassistant.components.notify import (
ATTR_DATA,
ATTR_TARGET,
ATTR_TITLE,
2017-08-21 10:23:29 +02:00
BaseNotificationService,
)
2023-06-06 16:01:40 +02:00
from homeassistant.const import ATTR_ICON, CONF_PATH
from homeassistant.core import HomeAssistant, callback
2021-07-15 06:44:57 +02:00
from homeassistant.helpers import aiohttp_client, config_validation as cv, template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
2015-08-01 06:45:41 +10:00
from .const import (
ATTR_BLOCKS,
ATTR_BLOCKS_TEMPLATE,
ATTR_FILE,
ATTR_PASSWORD,
ATTR_PATH,
2023-09-21 17:06:55 +08:00
ATTR_THREAD_TS,
ATTR_URL,
ATTR_USERNAME,
CONF_DEFAULT_CHANNEL,
DATA_CLIENT,
2022-11-27 15:31:38 -05:00
SLACK_DATA,
)
2025-01-19 15:09:04 -05:00
from .utils import upload_file_to_slack
2015-08-01 06:45:41 +10:00
_LOGGER = logging.getLogger(__name__)

# File entry that points at a file on the local filesystem.
FILE_PATH_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_PATH): cv.isfile,
    }
)

# File entry that points at a remote URL. Username and password share the
# "credentials" inclusion group, so they must be given together or not at all.
FILE_URL_SCHEMA = vol.Schema(
    {
        vol.Required(ATTR_URL): cv.url,
        vol.Inclusive(ATTR_USERNAME, "credentials"): cv.string,
        vol.Inclusive(ATTR_PASSWORD, "credentials"): cv.string,
    }
)

# Service data for a file upload: either flavour of file entry, optionally
# posted into an existing thread.
DATA_FILE_SCHEMA = vol.Schema(
    {
        vol.Required(ATTR_FILE): vol.Any(FILE_PATH_SCHEMA, FILE_URL_SCHEMA),
        vol.Optional(ATTR_THREAD_TS): cv.string,
    }
)

# Service data for a plain message: presentation overrides, raw or templated
# blocks, and an optional thread timestamp.
DATA_TEXT_ONLY_SCHEMA = vol.Schema(
    {
        vol.Optional(ATTR_USERNAME): cv.string,
        vol.Optional(ATTR_ICON): cv.string,
        vol.Optional(ATTR_BLOCKS): list,
        vol.Optional(ATTR_BLOCKS_TEMPLATE): list,
        vol.Optional(ATTR_THREAD_TS): cv.string,
    }
)

# Top-level validator: a single mapping is wrapped in a list, and every item
# must match either the file-upload or the text-only layout.
DATA_SCHEMA = vol.All(
    cv.ensure_list,
    [vol.Any(DATA_FILE_SCHEMA, DATA_TEXT_ONLY_SCHEMA)],
)
2015-08-01 06:45:41 +10:00
2021-01-26 01:03:12 +01:00
class AuthDictT(TypedDict, total=False):
    """Extra keyword arguments for the aiohttp GET that fetches a remote file.

    Stays empty for anonymous downloads; ``auth`` is filled in only when both
    a username and a password were supplied.
    """

    auth: BasicAuth  # HTTP basic-auth credentials
2023-09-21 17:06:55 +08:00
class FormDataT(TypedDict, total=False):
    """Form fields for a Slack file-upload request.

    Declared with ``total=False``, so every field may be omitted.
    """

    channels: str  # target channel(s)
    filename: str  # name shown for the uploaded file
    initial_comment: str  # message text posted with the file
    title: str  # title shown for the uploaded file
    token: str  # Slack API token
    thread_ts: str  # parent message timestamp, to reply in a thread
2021-01-26 01:03:12 +01:00
class MessageT(TypedDict, total=False):
    """Keyword arguments passed to ``chat_postMessage`` (channel excluded).

    The sender always fills ``link_names`` and ``text``; each remaining key
    is set only when the corresponding option was provided.
    """

    link_names: bool
    text: str
    username: str  # display-name override
    icon_url: str  # chosen when the icon starts with http:// or https://
    icon_emoji: str  # chosen for any other icon value
    blocks: list[Any]  # Slack blocks, raw or rendered from a template
    thread_ts: str  # parent message timestamp, to reply in a thread
2021-01-26 01:03:12 +01:00
async def async_get_service(
    hass: HomeAssistant,
    config: ConfigType,
    discovery_info: DiscoveryInfoType | None = None,
) -> SlackNotificationService | None:
    """Set up the Slack notification service.

    The service is only built from discovery data, which carries the shared
    Slack client under ``SLACK_DATA``; without discovery info nothing is set up.
    """
    if not discovery_info:
        return None

    slack_client = discovery_info[SLACK_DATA][DATA_CLIENT]
    return SlackNotificationService(hass, slack_client, discovery_info)
@callback
def _async_get_filename_from_url(url: str) -> str:
    """Return the final path segment of *url*.

    Query string and fragment are ignored because only the parsed path
    component is inspected.
    """
    url_path = urlparse(url).path
    return os.path.basename(url_path)
@callback
def _async_sanitize_channel_names(channel_list: list[str]) -> list[str]:
    """Return a new list with leading ``#`` characters stripped from each name."""
    return [name.lstrip("#") for name in channel_list]
2015-08-01 06:45:41 +10:00
class SlackNotificationService(BaseNotificationService):
    """Define the Slack notification logic."""

    def __init__(
        self,
        hass: HomeAssistant,
        client: AsyncWebClient,
        config: dict[str, str],
    ) -> None:
        """Initialize the service.

        Args:
            hass: Home Assistant instance; used for the path/URL allow-lists
                and the shared aiohttp client session.
            client: Slack ``AsyncWebClient`` used for every API call.
            config: Discovery data; read for the default channel and the
                default username and icon.
        """
        self._hass = hass
        self._client = client
        self._config = config

    async def _async_resolve_channel_ids(self, targets: list[str]) -> list[str]:
        """Resolve each target to a Slack channel ID, dropping unresolved ones.

        Each failed lookup is logged by ``_async_get_channel_id``.
        """
        channel_ids = [await self._async_get_channel_id(target) for target in targets]
        return [channel_id for channel_id in channel_ids if channel_id]

    async def _async_send_local_file_message(
        self,
        path: str,
        targets: list[str],
        message: str,
        title: str | None,
        thread_ts: str | None,
    ) -> None:
        """Upload a local file (with message) to Slack.

        The path must pass ``hass.config.is_allowed_path``; otherwise an error
        is logged and nothing is sent.
        """
        if not self._hass.config.is_allowed_path(path):
            _LOGGER.error("Path does not exist or is not allowed: %s", path)
            return

        # Use the raw path: parsing a filesystem path as a URL would treat a
        # "#" or "?" in the filename as a fragment/query and truncate it.
        filename = os.path.basename(path)

        channel_ids = await self._async_resolve_channel_ids(targets)
        if not channel_ids:
            _LOGGER.error("No valid channel IDs resolved for targets: %s", targets)
            return

        await upload_file_to_slack(
            client=self._client,
            channel_ids=channel_ids,
            file_content=None,
            file_path=path,
            filename=filename,
            title=title,
            message=message,
            thread_ts=thread_ts,
        )

    async def _async_send_remote_file_message(
        self,
        url: str,
        targets: list[str],
        message: str,
        title: str | None,
        thread_ts: str | None,
        *,
        username: str | None = None,
        password: str | None = None,
    ) -> None:
        """Download a remote file and upload it (with message) to Slack.

        The URL must pass ``hass.config.is_allowed_external_url``. HTTP basic
        auth is used only when both ``username`` and ``password`` are given.
        Download errors are logged and abort the send.
        """
        if not self._hass.config.is_allowed_external_url(url):
            _LOGGER.error("URL is not allowed: %s", url)
            return

        filename = _async_get_filename_from_url(url)
        session = aiohttp_client.async_get_clientsession(self._hass)

        # Fetch the remote file
        kwargs: AuthDictT = {}
        if username and password:
            kwargs = {"auth": BasicAuth(username, password=password)}

        try:
            async with session.get(url, **kwargs) as resp:
                resp.raise_for_status()
                file_content = await resp.read()
        except ClientError as err:
            _LOGGER.error("Error while retrieving %s: %r", url, err)
            return

        channel_ids = await self._async_resolve_channel_ids(targets)
        if not channel_ids:
            _LOGGER.error("No valid channel IDs resolved for targets: %s", targets)
            return

        await upload_file_to_slack(
            client=self._client,
            channel_ids=channel_ids,
            file_content=file_content,
            filename=filename,
            title=title,
            message=message,
            thread_ts=thread_ts,
        )

    async def _async_send_text_only_message(
        self,
        targets: list[str],
        message: str,
        title: str | None,
        thread_ts: str | None,
        *,
        username: str | None = None,
        icon: str | None = None,
        blocks: Any | None = None,
    ) -> None:
        """Send a text-only message to every target concurrently.

        ``title`` is accepted for signature parity with the file senders but
        is not used. An ``icon`` starting with http:// or https:// is sent as
        ``icon_url``; any other value is sent as ``icon_emoji``. Per-target
        failures are logged and do not stop delivery to the other targets.
        """
        message_dict: MessageT = {"link_names": True, "text": message}

        if username:
            message_dict["username"] = username

        if icon:
            if icon.lower().startswith(("http://", "https://")):
                message_dict["icon_url"] = icon
            else:
                message_dict["icon_emoji"] = icon

        if blocks:
            message_dict["blocks"] = blocks

        if thread_ts:
            message_dict["thread_ts"] = thread_ts

        # De-duplicate (keeping order) before creating coroutines: a repeated
        # target would otherwise overwrite its dict entry and leave the first
        # coroutine un-awaited.
        unique_targets = list(dict.fromkeys(targets))
        tasks = {
            target: self._client.chat_postMessage(**message_dict, channel=target)
            for target in unique_targets
        }

        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        for target, result in zip(tasks, results, strict=True):
            if isinstance(result, SlackApiError):
                _LOGGER.error(
                    "There was a Slack API error while sending to %s: %r",
                    target,
                    result,
                )
            elif isinstance(result, ClientError):
                _LOGGER.error("Error while sending message to %s: %r", target, result)
            elif isinstance(result, Exception):
                # e.g. a timeout: previously dropped without any trace.
                _LOGGER.error(
                    "Unexpected error while sending message to %s: %r",
                    target,
                    result,
                )

    async def async_send_message(self, message: str, **kwargs: Any) -> None:
        """Send a message to Slack.

        Depending on the service ``data`` this sends a text-only message
        (optionally with raw or templated blocks), uploads a remote file, or
        uploads a local file. Data that fails ``DATA_SCHEMA`` is logged and
        discarded, which falls back to a plain text message.
        """
        data = kwargs.get(ATTR_DATA) or {}

        try:
            DATA_SCHEMA(data)
        except vol.Invalid as err:
            _LOGGER.error("Invalid message data: %s", err)
            data = {}

        title = kwargs.get(ATTR_TITLE)
        targets = _async_sanitize_channel_names(
            kwargs.get(ATTR_TARGET, [self._config[CONF_DEFAULT_CHANNEL]])
        )

        # Message Type 1: A text-only message
        if ATTR_FILE not in data:
            if ATTR_BLOCKS_TEMPLATE in data:
                value = cv.template_complex(data[ATTR_BLOCKS_TEMPLATE])
                blocks = template.render_complex(value)
            elif ATTR_BLOCKS in data:
                blocks = data[ATTR_BLOCKS]
            else:
                blocks = None

            return await self._async_send_text_only_message(
                targets,
                message,
                title,
                username=data.get(ATTR_USERNAME, self._config.get(ATTR_USERNAME)),
                icon=data.get(ATTR_ICON, self._config.get(ATTR_ICON)),
                thread_ts=data.get(ATTR_THREAD_TS),
                blocks=blocks,
            )

        # Message Type 2: A message that uploads a remote file
        if ATTR_URL in data[ATTR_FILE]:
            return await self._async_send_remote_file_message(
                data[ATTR_FILE][ATTR_URL],
                targets,
                message,
                title,
                thread_ts=data.get(ATTR_THREAD_TS),
                username=data[ATTR_FILE].get(ATTR_USERNAME),
                password=data[ATTR_FILE].get(ATTR_PASSWORD),
            )

        # Message Type 3: A message that uploads a local file
        return await self._async_send_local_file_message(
            data[ATTR_FILE][ATTR_PATH],
            targets,
            message,
            title,
            thread_ts=data.get(ATTR_THREAD_TS),
        )

    async def _async_list_channels(self) -> list[dict[str, Any]]:
        """Return every public and private channel visible to the client.

        ``conversations.list`` is cursor-paginated, so each channel type is
        fetched page by page until no ``next_cursor`` is returned.

        Raises:
            SlackApiError: If Slack rejects a request.
            ClientError: If the HTTP request itself fails.
        """
        channels: list[dict[str, Any]] = []
        # Requesting several types in one call did not work when tested
        # (https://api.slack.com/methods/conversations.list/test), so each
        # type is listed separately.
        for channel_type in ("public_channel", "private_channel"):
            response = await self._client.conversations_list(types=channel_type)
            while True:
                channels.extend(response["channels"])
                metadata = response.get("response_metadata") or {}
                next_cursor = metadata.get("next_cursor")
                # A missing or empty cursor marks the last page.
                if not isinstance(next_cursor, str) or not next_cursor:
                    break
                response = await self._client.conversations_list(
                    types=channel_type, cursor=next_cursor
                )
        return channels

    async def _async_get_channel_id(self, channel_name: str) -> str | None:
        """Get the Slack channel ID for a channel name.

        Lets users target channels by name (with or without a leading "#").
        A value that already equals a channel's ID is accepted too, so
        targets given as IDs work for file uploads as they do for text.

        Args:
            channel_name: Channel name (leading "#" ignored) or channel ID.

        Returns:
            The channel ID if found, otherwise ``None``. Slack API and
            connection errors are logged and also yield ``None``.
        """
        channel_name = channel_name.lstrip("#")

        try:
            channels = await self._async_list_channels()
        except (SlackApiError, ClientError) as err:
            _LOGGER.error("Error getting channel ID: %r", err)
            return None

        for channel in channels:
            if channel["name"] == channel_name or channel.get("id") == channel_name:
                return cast(str, channel["id"])

        _LOGGER.error("Channel %s not found", channel_name)
        return None