mirror of
https://github.com/home-assistant/core.git
synced 2026-03-21 02:04:51 +01:00
Compare commits
19 Commits
esphome-ff
...
use-unix-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0888dcc1da | ||
|
|
63bc4564b2 | ||
|
|
03817ccc07 | ||
|
|
f0c56d74a4 | ||
|
|
58d8824a44 | ||
|
|
d93b45fe35 | ||
|
|
88b9e6cd83 | ||
|
|
fdde93187a | ||
|
|
da29f06c2c | ||
|
|
cccb252b8d | ||
|
|
ea556d65cb | ||
|
|
f499a0b45b | ||
|
|
95d76e8e80 | ||
|
|
c3be74c1cd | ||
|
|
b6be7a12b1 | ||
|
|
72db92b17b | ||
|
|
c5889082c0 | ||
|
|
68d94badc6 | ||
|
|
275374ec0d |
@@ -10,6 +10,7 @@ from functools import partial
|
||||
from ipaddress import IPv4Network, IPv6Network, ip_network
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import socket
|
||||
import ssl
|
||||
from tempfile import NamedTemporaryFile
|
||||
@@ -33,6 +34,7 @@ from homeassistant.components.network import async_get_source_ip
|
||||
from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_START,
|
||||
EVENT_HOMEASSISTANT_STOP,
|
||||
HASSIO_USER_NAME,
|
||||
SERVER_PORT,
|
||||
)
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
@@ -69,7 +71,7 @@ from .headers import setup_headers
|
||||
from .request_context import setup_request_context
|
||||
from .security_filter import setup_security_filter
|
||||
from .static import CACHE_HEADERS, CachingStaticResource
|
||||
from .web_runner import HomeAssistantTCPSite
|
||||
from .web_runner import HomeAssistantTCPSite, HomeAssistantUnixSite
|
||||
|
||||
CONF_SERVER_HOST: Final = "server_host"
|
||||
CONF_SERVER_PORT: Final = "server_port"
|
||||
@@ -235,6 +237,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
source_ip_task = create_eager_task(async_get_source_ip(hass))
|
||||
|
||||
unix_socket_path: Path | None = None
|
||||
if socket_env := os.environ.get("SUPERVISOR_CORE_API_SOCKET"):
|
||||
socket_path = Path(socket_env)
|
||||
if socket_path.is_absolute():
|
||||
unix_socket_path = socket_path
|
||||
else:
|
||||
_LOGGER.error(
|
||||
"Invalid unix socket path %s: path must be absolute", socket_env
|
||||
)
|
||||
|
||||
server = HomeAssistantHTTP(
|
||||
hass,
|
||||
server_host=server_host,
|
||||
@@ -244,6 +256,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
ssl_key=ssl_key,
|
||||
trusted_proxies=trusted_proxies,
|
||||
ssl_profile=ssl_profile,
|
||||
unix_socket_path=unix_socket_path,
|
||||
)
|
||||
await server.async_initialize(
|
||||
cors_origins=cors_origins,
|
||||
@@ -267,6 +280,21 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
async_when_setup_or_start(hass, "frontend", start_server)
|
||||
|
||||
if server.unix_socket_path is not None:
|
||||
|
||||
async def start_unix_socket(*_: Any) -> None:
|
||||
"""Start the Unix socket after the Supervisor user is available."""
|
||||
if any(
|
||||
user
|
||||
for user in await hass.auth.async_get_users()
|
||||
if user.system_generated and user.name == HASSIO_USER_NAME
|
||||
):
|
||||
await server.async_start_unix_socket()
|
||||
else:
|
||||
_LOGGER.error("Supervisor user not found; not starting Unix socket")
|
||||
|
||||
async_when_setup_or_start(hass, "hassio", start_unix_socket)
|
||||
|
||||
hass.http = server
|
||||
|
||||
local_ip = await source_ip_task
|
||||
@@ -366,6 +394,7 @@ class HomeAssistantHTTP:
|
||||
server_port: int,
|
||||
trusted_proxies: list[IPv4Network | IPv6Network],
|
||||
ssl_profile: str,
|
||||
unix_socket_path: Path | None = None,
|
||||
) -> None:
|
||||
"""Initialize the HTTP Home Assistant server."""
|
||||
self.app = HomeAssistantApplication(
|
||||
@@ -384,8 +413,10 @@ class HomeAssistantHTTP:
|
||||
self.server_port = server_port
|
||||
self.trusted_proxies = trusted_proxies
|
||||
self.ssl_profile = ssl_profile
|
||||
self.unix_socket_path = unix_socket_path
|
||||
self.runner: web.AppRunner | None = None
|
||||
self.site: HomeAssistantTCPSite | None = None
|
||||
self.unix_site: HomeAssistantUnixSite | None = None
|
||||
self.context: ssl.SSLContext | None = None
|
||||
|
||||
async def async_initialize(
|
||||
@@ -610,6 +641,29 @@ class HomeAssistantHTTP:
|
||||
context.load_cert_chain(cert_pem.name, key_pem.name)
|
||||
return context
|
||||
|
||||
async def async_start_unix_socket(self) -> None:
|
||||
"""Start listening on the Unix socket.
|
||||
|
||||
This is called separately from start() to delay serving the Unix
|
||||
socket until the Supervisor user exists (created by the hassio
|
||||
integration). Without this delay, Supervisor could connect before
|
||||
its user is available and receive 401 responses it won't retry.
|
||||
"""
|
||||
if self.unix_socket_path is None or self.runner is None:
|
||||
return
|
||||
self.unix_site = HomeAssistantUnixSite(self.runner, self.unix_socket_path)
|
||||
try:
|
||||
await self.unix_site.start()
|
||||
except OSError as error:
|
||||
_LOGGER.error(
|
||||
"Failed to create HTTP server on unix socket %s: %s",
|
||||
self.unix_socket_path,
|
||||
error,
|
||||
)
|
||||
self.unix_site = None
|
||||
else:
|
||||
_LOGGER.info("Now listening on unix socket %s", self.unix_socket_path)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the aiohttp server."""
|
||||
# Aiohttp freezes apps after start so that no changes can be made.
|
||||
@@ -637,6 +691,19 @@ class HomeAssistantHTTP:
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the aiohttp server."""
|
||||
if self.unix_site is not None:
|
||||
await self.unix_site.stop()
|
||||
if self.unix_socket_path is not None:
|
||||
try:
|
||||
await self.hass.async_add_executor_job(
|
||||
self.unix_socket_path.unlink, True
|
||||
)
|
||||
except OSError as err:
|
||||
_LOGGER.warning(
|
||||
"Could not remove unix socket %s: %s",
|
||||
self.unix_socket_path,
|
||||
err,
|
||||
)
|
||||
if self.site is not None:
|
||||
await self.site.stop()
|
||||
if self.runner is not None:
|
||||
|
||||
@@ -11,7 +11,13 @@ import time
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp import hdrs
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
from aiohttp.web import (
|
||||
Application,
|
||||
HTTPInternalServerError,
|
||||
Request,
|
||||
StreamResponse,
|
||||
middleware,
|
||||
)
|
||||
import jwt
|
||||
from jwt import api_jws
|
||||
from yarl import URL
|
||||
@@ -20,6 +26,7 @@ from homeassistant.auth import jwt_wrapper
|
||||
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.const import HASSIO_USER_NAME
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.http import current_request
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
@@ -27,7 +34,12 @@ from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||
from .const import (
|
||||
KEY_AUTHENTICATED,
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
is_unix_socket_request,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -117,7 +129,7 @@ def async_user_not_allowed_do_auth(
|
||||
return "User cannot authenticate remotely"
|
||||
|
||||
|
||||
async def async_setup_auth(
|
||||
async def async_setup_auth( # noqa: C901
|
||||
hass: HomeAssistant,
|
||||
app: Application,
|
||||
) -> None:
|
||||
@@ -207,6 +219,41 @@ async def async_setup_auth(
|
||||
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||
return True
|
||||
|
||||
supervisor_user_id: str | None = None
|
||||
|
||||
async def async_authenticate_unix_socket(request: Request) -> bool:
|
||||
"""Authenticate a request from a Unix socket as the Supervisor user.
|
||||
|
||||
The Unix Socket is dedicated and only available to Supervisor. To
|
||||
avoid the extra overhead and round trips for the authentication and
|
||||
refresh tokens, we directly authenticate requests from the socket as
|
||||
the Supervisor user.
|
||||
"""
|
||||
nonlocal supervisor_user_id
|
||||
|
||||
# Fast path: use cached user ID
|
||||
if supervisor_user_id is not None:
|
||||
if user := await hass.auth.async_get_user(supervisor_user_id):
|
||||
request[KEY_HASS_USER] = user
|
||||
return True
|
||||
supervisor_user_id = None
|
||||
|
||||
# Slow path: find the Supervisor user by name
|
||||
for user in await hass.auth.async_get_users():
|
||||
if user.system_generated and user.name == HASSIO_USER_NAME:
|
||||
supervisor_user_id = user.id
|
||||
# Not setting KEY_HASS_REFRESH_TOKEN_ID since Supervisor user
|
||||
# doesn't use refresh tokens.
|
||||
request[KEY_HASS_USER] = user
|
||||
return True
|
||||
|
||||
# The Unix socket should not be serving before the hassio integration
|
||||
# has created the Supervisor user. If we get here, something is wrong.
|
||||
_LOGGER.error(
|
||||
"Supervisor user not found; cannot authenticate Unix socket request"
|
||||
)
|
||||
raise HTTPInternalServerError
|
||||
|
||||
@middleware
|
||||
async def auth_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
@@ -214,7 +261,11 @@ async def async_setup_auth(
|
||||
"""Authenticate as middleware."""
|
||||
authenticated = False
|
||||
|
||||
if hdrs.AUTHORIZATION in request.headers and async_validate_auth_header(
|
||||
if is_unix_socket_request(request):
|
||||
authenticated = await async_authenticate_unix_socket(request)
|
||||
auth_type = "unix socket"
|
||||
|
||||
elif hdrs.AUTHORIZATION in request.headers and async_validate_auth_header(
|
||||
request
|
||||
):
|
||||
authenticated = True
|
||||
@@ -233,7 +284,7 @@ async def async_setup_auth(
|
||||
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Authenticated %s for %s using %s",
|
||||
request.remote,
|
||||
request.remote or "unknown",
|
||||
request.path,
|
||||
auth_type,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.hassio import get_supervisor_ip, is_hassio
|
||||
from homeassistant.util import dt as dt_util, yaml as yaml_util
|
||||
|
||||
from .const import KEY_HASS
|
||||
from .const import KEY_HASS, is_unix_socket_request
|
||||
from .view import HomeAssistantView
|
||||
|
||||
_LOGGER: Final = logging.getLogger(__name__)
|
||||
@@ -72,6 +72,10 @@ async def ban_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""IP Ban middleware."""
|
||||
# Unix socket connections are trusted, skip ban checks
|
||||
if is_unix_socket_request(request):
|
||||
return await handler(request)
|
||||
|
||||
if (ban_manager := request.app.get(KEY_BAN_MANAGER)) is None:
|
||||
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
|
||||
return await handler(request)
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
"""HTTP specific constants."""
|
||||
|
||||
import socket
|
||||
from typing import Final
|
||||
|
||||
from aiohttp.web import Request
|
||||
|
||||
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
|
||||
|
||||
DOMAIN: Final = "http"
|
||||
|
||||
KEY_HASS_USER: Final = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
|
||||
|
||||
|
||||
def is_unix_socket_request(request: Request) -> bool:
|
||||
"""Check if request arrived over a Unix socket."""
|
||||
if (transport := request.transport) is None:
|
||||
return False
|
||||
if (sock := transport.get_extra_info("socket")) is None:
|
||||
return False
|
||||
return bool(sock.family == socket.AF_UNIX)
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import socket
|
||||
from ssl import SSLContext
|
||||
|
||||
from aiohttp import web
|
||||
@@ -68,3 +70,62 @@ class HomeAssistantTCPSite(web.BaseSite):
|
||||
reuse_address=self._reuse_address,
|
||||
reuse_port=self._reuse_port,
|
||||
)
|
||||
|
||||
|
||||
class HomeAssistantUnixSite(web.BaseSite):
|
||||
"""HomeAssistant specific aiohttp UnixSite.
|
||||
|
||||
Listens on a Unix socket for local inter-process communication,
|
||||
used for Supervisor to Core communication.
|
||||
"""
|
||||
|
||||
__slots__ = ("_path",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runner: web.BaseRunner,
|
||||
path: Path,
|
||||
*,
|
||||
backlog: int = 128,
|
||||
) -> None:
|
||||
"""Initialize HomeAssistantUnixSite."""
|
||||
super().__init__(
|
||||
runner,
|
||||
backlog=backlog,
|
||||
)
|
||||
self._path = path
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return server URL."""
|
||||
return f"http://unix:{self._path}:"
|
||||
|
||||
def _create_unix_socket(self) -> socket.socket:
|
||||
"""Create and bind a Unix domain socket.
|
||||
|
||||
Performs blocking filesystem I/O (mkdir, unlink, chmod) and is
|
||||
intended to be run in an executor. Permissions are set after bind
|
||||
but before the socket is handed to the event loop, so no
|
||||
connections can arrive on an unrestricted socket.
|
||||
"""
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._path.unlink(missing_ok=True)
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.bind(str(self._path))
|
||||
except OSError:
|
||||
sock.close()
|
||||
raise
|
||||
self._path.chmod(0o600)
|
||||
return sock
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start server."""
|
||||
await super().start()
|
||||
loop = asyncio.get_running_loop()
|
||||
sock = await loop.run_in_executor(None, self._create_unix_socket)
|
||||
server = self._runner.server
|
||||
assert server is not None
|
||||
self._server = await loop.create_unix_server(
|
||||
server, sock=sock, backlog=self._backlog
|
||||
)
|
||||
|
||||
@@ -283,7 +283,10 @@ class IntegrationOnboardingView(_BaseOnboardingStepView):
|
||||
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
|
||||
"""Handle token creation."""
|
||||
hass = request.app[KEY_HASS]
|
||||
refresh_token_id = request[KEY_HASS_REFRESH_TOKEN_ID]
|
||||
if not (refresh_token_id := request.get(KEY_HASS_REFRESH_TOKEN_ID)):
|
||||
return self.json_message(
|
||||
"Refresh token not available", HTTPStatus.FORBIDDEN
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
if self._async_is_done():
|
||||
|
||||
@@ -10,6 +10,7 @@ import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.components.http.ban import process_success_login, process_wrong_login
|
||||
from homeassistant.components.http.const import KEY_HASS_USER
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
@@ -68,6 +69,19 @@ class AuthPhase:
|
||||
# send_bytes_text will directly send a message to the client.
|
||||
self._send_bytes_text = send_bytes_text
|
||||
|
||||
async def async_handle_unix_socket(self) -> ActiveConnection:
|
||||
"""Handle a pre-authenticated Unix socket connection."""
|
||||
conn = ActiveConnection(
|
||||
self._logger,
|
||||
self._hass,
|
||||
self._send_message,
|
||||
self._request[KEY_HASS_USER],
|
||||
refresh_token=None,
|
||||
)
|
||||
await self._send_bytes_text(AUTH_OK_MESSAGE)
|
||||
self._logger.debug("Auth OK (unix socket)")
|
||||
return conn
|
||||
|
||||
async def async_handle(self, msg: JsonValueType) -> ActiveConnection:
|
||||
"""Handle authentication."""
|
||||
try:
|
||||
|
||||
@@ -59,14 +59,14 @@ class ActiveConnection:
|
||||
hass: HomeAssistant,
|
||||
send_message: Callable[[bytes | str | dict[str, Any]], None],
|
||||
user: User,
|
||||
refresh_token: RefreshToken,
|
||||
refresh_token: RefreshToken | None,
|
||||
) -> None:
|
||||
"""Initialize an active connection."""
|
||||
self.logger = logger
|
||||
self.hass = hass
|
||||
self.send_message = send_message
|
||||
self.user = user
|
||||
self.refresh_token_id = refresh_token.id
|
||||
self.refresh_token_id = refresh_token.id if refresh_token else None
|
||||
self.subscriptions: dict[Hashable, Callable[[], Any]] = {}
|
||||
self.last_id = 0
|
||||
self.can_coalesce = False
|
||||
|
||||
@@ -14,6 +14,7 @@ from aiohttp import WSMsgType, web
|
||||
from aiohttp.http_websocket import WebSocketWriter
|
||||
|
||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
||||
from homeassistant.components.http.const import is_unix_socket_request
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, EVENT_LOGGING_CHANGED
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
@@ -36,12 +37,12 @@ from .error import Disconnect
|
||||
from .messages import message_to_json_bytes
|
||||
from .util import describe_request
|
||||
|
||||
CLOSE_MSG_TYPES = {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING}
|
||||
AUTH_MESSAGE_TIMEOUT = 10 # seconds
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .connection import ActiveConnection
|
||||
|
||||
CLOSE_MSG_TYPES = {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING}
|
||||
AUTH_MESSAGE_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
_WS_LOGGER: Final = logging.getLogger(f"{__name__}.connection")
|
||||
|
||||
@@ -386,37 +387,45 @@ class WebSocketHandler:
|
||||
send_bytes_text: Callable[[bytes], Coroutine[Any, Any, None]],
|
||||
) -> ActiveConnection:
|
||||
"""Handle the auth phase of the websocket connection."""
|
||||
await send_bytes_text(AUTH_REQUIRED_MESSAGE)
|
||||
request = self._request
|
||||
|
||||
# Auth Phase
|
||||
try:
|
||||
msg = await self._wsock.receive(AUTH_MESSAGE_TIMEOUT)
|
||||
except TimeoutError as err:
|
||||
raise Disconnect(
|
||||
f"Did not receive auth message within {AUTH_MESSAGE_TIMEOUT} seconds"
|
||||
) from err
|
||||
if is_unix_socket_request(request):
|
||||
# Unix socket requests are pre-authenticated by the HTTP
|
||||
# auth middleware — skip the token exchange.
|
||||
connection = await auth.async_handle_unix_socket()
|
||||
else:
|
||||
await send_bytes_text(AUTH_REQUIRED_MESSAGE)
|
||||
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||
raise Disconnect("Received close message during auth phase")
|
||||
|
||||
if msg.type is not WSMsgType.TEXT:
|
||||
if msg.type is WSMsgType.ERROR:
|
||||
# msg.data is the exception
|
||||
# Auth Phase
|
||||
try:
|
||||
msg = await self._wsock.receive(AUTH_MESSAGE_TIMEOUT)
|
||||
except TimeoutError as err:
|
||||
raise Disconnect(
|
||||
f"Received error message during auth phase: {msg.data}"
|
||||
f"Did not receive auth message within {AUTH_MESSAGE_TIMEOUT} seconds"
|
||||
) from err
|
||||
|
||||
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.CLOSING):
|
||||
raise Disconnect("Received close message during auth phase")
|
||||
|
||||
if msg.type is not WSMsgType.TEXT:
|
||||
if msg.type is WSMsgType.ERROR:
|
||||
# msg.data is the exception
|
||||
raise Disconnect(
|
||||
f"Received error message during auth phase: {msg.data}"
|
||||
)
|
||||
raise Disconnect(
|
||||
f"Received non-Text message of type {msg.type} during auth phase"
|
||||
)
|
||||
raise Disconnect(
|
||||
f"Received non-Text message of type {msg.type} during auth phase"
|
||||
)
|
||||
|
||||
try:
|
||||
auth_msg_data = json_loads(msg.data)
|
||||
except ValueError as err:
|
||||
raise Disconnect("Received invalid JSON during auth phase") from err
|
||||
try:
|
||||
auth_msg_data = json_loads(msg.data)
|
||||
except ValueError as err:
|
||||
raise Disconnect("Received invalid JSON during auth phase") from err
|
||||
|
||||
if self._debug:
|
||||
self._logger.debug("%s: Received %s", self.description, auth_msg_data)
|
||||
connection = await auth.async_handle(auth_msg_data)
|
||||
|
||||
if self._debug:
|
||||
self._logger.debug("%s: Received %s", self.description, auth_msg_data)
|
||||
connection = await auth.async_handle(auth_msg_data)
|
||||
# As the webserver is now started before the start
|
||||
# event we do not want to block for websocket responses
|
||||
#
|
||||
|
||||
@@ -13,7 +13,7 @@ import jwt
|
||||
import pytest
|
||||
import yarl
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.auth.providers import trusted_networks
|
||||
from homeassistant.auth.providers.homeassistant import HassAuthProvider
|
||||
@@ -32,6 +32,7 @@ from homeassistant.components.http.request_context import (
|
||||
current_request,
|
||||
setup_request_context,
|
||||
)
|
||||
from homeassistant.const import HASSIO_USER_NAME
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS
|
||||
from homeassistant.setup import async_setup_component
|
||||
@@ -658,3 +659,78 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
|
||||
|
||||
# test it did not create a user
|
||||
assert len(await hass.auth.async_get_users()) == cur_users + 1
|
||||
|
||||
|
||||
async def test_unix_socket_auth_with_supervisor_user(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test that Unix socket requests are authenticated as Supervisor user."""
|
||||
supervisor_user = await hass.auth.async_create_system_user(
|
||||
HASSIO_USER_NAME, group_ids=[GROUP_ID_ADMIN]
|
||||
)
|
||||
await hass.auth.async_create_refresh_token(supervisor_user)
|
||||
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request", return_value=True
|
||||
):
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
data = await req.json()
|
||||
assert data["user_id"] == supervisor_user.id
|
||||
|
||||
|
||||
async def test_unix_socket_auth_without_supervisor_user(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test that Unix socket requests return 500 when no Supervisor user exists."""
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request", return_value=True
|
||||
):
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
async def test_unix_socket_auth_caches_user_id(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test that Unix socket auth caches the Supervisor user ID."""
|
||||
supervisor_user = await hass.auth.async_create_system_user(
|
||||
HASSIO_USER_NAME, group_ids=[GROUP_ID_ADMIN]
|
||||
)
|
||||
await hass.auth.async_create_refresh_token(supervisor_user)
|
||||
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request", return_value=True
|
||||
):
|
||||
# First request triggers user lookup
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
|
||||
# Second request should use cached user ID
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
hass.auth, "async_get_users", wraps=hass.auth.async_get_users
|
||||
) as mock_get_users,
|
||||
):
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
mock_get_users.assert_not_called()
|
||||
|
||||
@@ -466,3 +466,33 @@ async def test_single_ban_file_entry(
|
||||
await manager.async_add_ban(remote_ip)
|
||||
|
||||
assert m_open.call_count == 1
|
||||
|
||||
|
||||
async def test_unix_socket_skips_ban_check(
|
||||
hass: HomeAssistant, aiohttp_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test that Unix socket requests bypass ban middleware."""
|
||||
app = web.Application()
|
||||
app[KEY_HASS] = hass
|
||||
setup_bans(hass, app, 5)
|
||||
set_real_ip = mock_real_ip(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
return_value={
|
||||
banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS
|
||||
},
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Verify the IP is actually banned for normal requests
|
||||
set_real_ip(BANNED_IPS[0])
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.FORBIDDEN
|
||||
|
||||
# Unix socket requests should bypass ban checks
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.is_unix_socket_request", return_value=True
|
||||
):
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_network
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
@@ -14,6 +15,7 @@ import pytest
|
||||
from homeassistant.auth.providers.homeassistant import HassAuthProvider
|
||||
from homeassistant.components import cloud, http
|
||||
from homeassistant.components.cloud import CloudNotAvailable
|
||||
from homeassistant.const import HASSIO_USER_NAME
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import issue_registry as ir
|
||||
from homeassistant.helpers.http import KEY_HASS
|
||||
@@ -735,3 +737,74 @@ async def test_server_host(
|
||||
)
|
||||
|
||||
assert set(issue_registry.issues) == expected_issues
|
||||
|
||||
|
||||
async def test_unix_socket_started_with_supervisor(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unix socket is started when running under Supervisor."""
|
||||
await hass.auth.async_create_system_user(
|
||||
HASSIO_USER_NAME, group_ids=["system-admin"]
|
||||
)
|
||||
socket_path = tmp_path / "core.sock"
|
||||
loop = asyncio.get_running_loop()
|
||||
mock_sock = Mock()
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ, {"SUPERVISOR_CORE_API_SOCKET": str(socket_path)}, clear=False
|
||||
),
|
||||
patch("asyncio.BaseEventLoop.create_server", return_value=Mock()),
|
||||
patch(
|
||||
"homeassistant.components.http.web_runner.HomeAssistantUnixSite"
|
||||
"._create_unix_socket",
|
||||
return_value=mock_sock,
|
||||
) as mock_create_sock,
|
||||
patch.object(
|
||||
loop, "create_unix_server", return_value=Mock()
|
||||
) as mock_create_unix,
|
||||
):
|
||||
assert await async_setup_component(hass, "http", {"http": {}})
|
||||
await hass.async_start()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mock_create_sock.assert_called_once()
|
||||
mock_create_unix.assert_called_once_with(ANY, sock=mock_sock, backlog=128)
|
||||
assert hass.http.unix_site is not None
|
||||
|
||||
|
||||
async def test_unix_socket_not_started_without_supervisor(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test unix socket is not started when not running under Supervisor."""
|
||||
with (
|
||||
patch.dict(os.environ, {}, clear=False),
|
||||
patch("asyncio.BaseEventLoop.create_server", return_value=Mock()),
|
||||
):
|
||||
os.environ.pop("SUPERVISOR_CORE_API_SOCKET", None)
|
||||
assert await async_setup_component(hass, "http", {"http": {}})
|
||||
await hass.async_start()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.http.unix_site is None
|
||||
|
||||
|
||||
async def test_unix_socket_rejected_relative_path(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test unix socket is rejected when path is relative."""
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{"SUPERVISOR_CORE_API_SOCKET": "relative/path.sock"},
|
||||
clear=False,
|
||||
),
|
||||
patch("asyncio.BaseEventLoop.create_server", return_value=Mock()),
|
||||
):
|
||||
assert await async_setup_component(hass, "http", {"http": {}})
|
||||
await hass.async_start()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert hass.http.unix_site is None
|
||||
assert "path must be absolute" in caplog.text
|
||||
|
||||
@@ -18,6 +18,7 @@ from homeassistant.components.websocket_api.const import (
|
||||
SIGNAL_WEBSOCKET_DISCONNECTED,
|
||||
URL,
|
||||
)
|
||||
from homeassistant.const import HASSIO_USER_NAME
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.setup import async_setup_component
|
||||
@@ -367,3 +368,43 @@ async def test_error_right_after_auth_disconnects(
|
||||
assert close_error_msg.type is WSMsgType.CLOSE
|
||||
|
||||
assert "Received error message during command phase: explode" in caplog.text
|
||||
|
||||
|
||||
async def test_unix_socket_auth_bypass(
|
||||
hass: HomeAssistant, hass_client_no_auth: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test that Unix socket connections skip websocket auth phase."""
|
||||
# Create the Supervisor system user
|
||||
await hass.auth.async_create_system_user(
|
||||
HASSIO_USER_NAME, group_ids=["system-admin"]
|
||||
)
|
||||
|
||||
assert await async_setup_component(hass, "websocket_api", {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.http.ban.is_unix_socket_request",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.websocket_api.http.is_unix_socket_request",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
async with client.ws_connect(URL) as ws:
|
||||
# Should immediately receive auth_ok without sending a token
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg["type"] == TYPE_AUTH_OK
|
||||
|
||||
# Verify the connection works by sending a ping
|
||||
await ws.send_json({"id": 1, "type": "ping"})
|
||||
pong_msg = await ws.receive_json()
|
||||
assert pong_msg["type"] == "pong"
|
||||
assert pong_msg["id"] == 1
|
||||
|
||||
Reference in New Issue
Block a user