mirror of
https://github.com/home-assistant/core.git
synced 2026-02-24 11:11:16 +01:00
Compare commits
3 Commits
dev
...
use-unix-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5889082c0 | ||
|
|
68d94badc6 | ||
|
|
275374ec0d |
@@ -10,6 +10,7 @@ from functools import partial
|
||||
from ipaddress import IPv4Network, IPv6Network, ip_network
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import socket
|
||||
import ssl
|
||||
from tempfile import NamedTemporaryFile
|
||||
@@ -69,7 +70,7 @@ from .headers import setup_headers
|
||||
from .request_context import setup_request_context
|
||||
from .security_filter import setup_security_filter
|
||||
from .static import CACHE_HEADERS, CachingStaticResource
|
||||
from .web_runner import HomeAssistantTCPSite
|
||||
from .web_runner import HomeAssistantTCPSite, HomeAssistantUnixSite
|
||||
|
||||
CONF_SERVER_HOST: Final = "server_host"
|
||||
CONF_SERVER_PORT: Final = "server_port"
|
||||
@@ -235,6 +236,16 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
source_ip_task = create_eager_task(async_get_source_ip(hass))
|
||||
|
||||
unix_socket_path: Path | None = None
|
||||
if socket_env := os.environ.get("SUPERVISOR_CORE_API_SOCKET"):
|
||||
socket_path = Path(socket_env)
|
||||
if socket_path.is_absolute():
|
||||
unix_socket_path = socket_path
|
||||
else:
|
||||
_LOGGER.error(
|
||||
"Invalid unix socket path %s: path must be absolute", socket_env
|
||||
)
|
||||
|
||||
server = HomeAssistantHTTP(
|
||||
hass,
|
||||
server_host=server_host,
|
||||
@@ -244,6 +255,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
ssl_key=ssl_key,
|
||||
trusted_proxies=trusted_proxies,
|
||||
ssl_profile=ssl_profile,
|
||||
unix_socket_path=unix_socket_path,
|
||||
)
|
||||
await server.async_initialize(
|
||||
cors_origins=cors_origins,
|
||||
@@ -366,6 +378,7 @@ class HomeAssistantHTTP:
|
||||
server_port: int,
|
||||
trusted_proxies: list[IPv4Network | IPv6Network],
|
||||
ssl_profile: str,
|
||||
unix_socket_path: Path | None = None,
|
||||
) -> None:
|
||||
"""Initialize the HTTP Home Assistant server."""
|
||||
self.app = HomeAssistantApplication(
|
||||
@@ -384,8 +397,10 @@ class HomeAssistantHTTP:
|
||||
self.server_port = server_port
|
||||
self.trusted_proxies = trusted_proxies
|
||||
self.ssl_profile = ssl_profile
|
||||
self.unix_socket_path = unix_socket_path
|
||||
self.runner: web.AppRunner | None = None
|
||||
self.site: HomeAssistantTCPSite | None = None
|
||||
self.unix_site: HomeAssistantUnixSite | None = None
|
||||
self.context: ssl.SSLContext | None = None
|
||||
|
||||
async def async_initialize(
|
||||
@@ -623,6 +638,20 @@ class HomeAssistantHTTP:
|
||||
)
|
||||
await self.runner.setup()
|
||||
|
||||
if self.unix_socket_path is not None:
|
||||
self.unix_site = HomeAssistantUnixSite(self.runner, self.unix_socket_path)
|
||||
try:
|
||||
await self.unix_site.start()
|
||||
except OSError as error:
|
||||
_LOGGER.error(
|
||||
"Failed to create HTTP server on unix socket %s: %s",
|
||||
self.unix_socket_path,
|
||||
error,
|
||||
)
|
||||
self.unix_site = None
|
||||
else:
|
||||
_LOGGER.info("Now listening on unix socket %s", self.unix_socket_path)
|
||||
|
||||
self.site = HomeAssistantTCPSite(
|
||||
self.runner, self.server_host, self.server_port, ssl_context=self.context
|
||||
)
|
||||
@@ -637,6 +666,10 @@ class HomeAssistantHTTP:
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the aiohttp server."""
|
||||
if self.unix_site is not None:
|
||||
await self.unix_site.stop()
|
||||
if self.unix_socket_path is not None:
|
||||
self.unix_socket_path.unlink(missing_ok=True)
|
||||
if self.site is not None:
|
||||
await self.site.stop()
|
||||
if self.runner is not None:
|
||||
|
||||
@@ -20,6 +20,7 @@ from homeassistant.auth import jwt_wrapper
|
||||
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.const import HASSIO_USER_NAME
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.http import current_request
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
@@ -27,7 +28,12 @@ from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||
from .const import (
|
||||
KEY_AUTHENTICATED,
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
is_unix_socket_request,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -117,7 +123,7 @@ def async_user_not_allowed_do_auth(
|
||||
return "User cannot authenticate remotely"
|
||||
|
||||
|
||||
async def async_setup_auth(
|
||||
async def async_setup_auth( # noqa: C901
|
||||
hass: HomeAssistant,
|
||||
app: Application,
|
||||
) -> None:
|
||||
@@ -207,6 +213,27 @@ async def async_setup_auth(
|
||||
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||
return True
|
||||
|
||||
supervisor_user_id: str | None = None
|
||||
|
||||
async def async_authenticate_unix_socket(request: Request) -> bool:
|
||||
"""Authenticate a request from a Unix socket as the Supervisor user."""
|
||||
nonlocal supervisor_user_id
|
||||
|
||||
# Fast path: use cached user ID
|
||||
if supervisor_user_id is not None:
|
||||
if user := await hass.auth.async_get_user(supervisor_user_id):
|
||||
request[KEY_HASS_USER] = user
|
||||
return True
|
||||
supervisor_user_id = None
|
||||
|
||||
# Slow path: find the Supervisor user by name
|
||||
for user in await hass.auth.async_get_users():
|
||||
if user.system_generated and user.name == HASSIO_USER_NAME:
|
||||
supervisor_user_id = user.id
|
||||
request[KEY_HASS_USER] = user
|
||||
return True
|
||||
return False
|
||||
|
||||
@middleware
|
||||
async def auth_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
@@ -214,7 +241,11 @@ async def async_setup_auth(
|
||||
"""Authenticate as middleware."""
|
||||
authenticated = False
|
||||
|
||||
if hdrs.AUTHORIZATION in request.headers and async_validate_auth_header(
|
||||
if is_unix_socket_request(request):
|
||||
authenticated = await async_authenticate_unix_socket(request)
|
||||
auth_type = "unix socket"
|
||||
|
||||
elif hdrs.AUTHORIZATION in request.headers and async_validate_auth_header(
|
||||
request
|
||||
):
|
||||
authenticated = True
|
||||
@@ -233,7 +264,7 @@ async def async_setup_auth(
|
||||
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Authenticated %s for %s using %s",
|
||||
request.remote,
|
||||
request.remote or "unknown",
|
||||
request.path,
|
||||
auth_type,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.hassio import get_supervisor_ip, is_hassio
|
||||
from homeassistant.util import dt as dt_util, yaml as yaml_util
|
||||
|
||||
from .const import KEY_HASS
|
||||
from .const import KEY_HASS, is_unix_socket_request
|
||||
from .view import HomeAssistantView
|
||||
|
||||
_LOGGER: Final = logging.getLogger(__name__)
|
||||
@@ -72,6 +72,10 @@ async def ban_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""IP Ban middleware."""
|
||||
# Unix socket connections are trusted, skip ban checks
|
||||
if is_unix_socket_request(request):
|
||||
return await handler(request)
|
||||
|
||||
if (ban_manager := request.app.get(KEY_BAN_MANAGER)) is None:
|
||||
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
|
||||
return await handler(request)
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
"""HTTP specific constants."""
|
||||
|
||||
import socket
|
||||
from typing import Final
|
||||
|
||||
from aiohttp.web import Request
|
||||
|
||||
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
|
||||
|
||||
DOMAIN: Final = "http"
|
||||
|
||||
KEY_HASS_USER: Final = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
|
||||
|
||||
|
||||
def is_unix_socket_request(request: Request) -> bool:
|
||||
"""Check if request arrived over a Unix socket."""
|
||||
if (transport := request.transport) is None:
|
||||
return False
|
||||
if (sock := transport.get_extra_info("socket")) is None:
|
||||
return False
|
||||
return bool(sock.family == socket.AF_UNIX)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from ssl import SSLContext
|
||||
|
||||
from aiohttp import web
|
||||
@@ -68,3 +69,46 @@ class HomeAssistantTCPSite(web.BaseSite):
|
||||
reuse_address=self._reuse_address,
|
||||
reuse_port=self._reuse_port,
|
||||
)
|
||||
|
||||
|
||||
class HomeAssistantUnixSite(web.BaseSite):
|
||||
"""HomeAssistant specific aiohttp UnixSite.
|
||||
|
||||
Listens on a Unix socket for local inter-process communication,
|
||||
used for Supervisor to Core communication.
|
||||
"""
|
||||
|
||||
__slots__ = ("_path",)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runner: web.BaseRunner,
|
||||
path: Path,
|
||||
*,
|
||||
backlog: int = 128,
|
||||
) -> None:
|
||||
"""Initialize HomeAssistantUnixSite."""
|
||||
super().__init__(
|
||||
runner,
|
||||
backlog=backlog,
|
||||
)
|
||||
self._path = path
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return server URL."""
|
||||
return f"http://unix:{self._path}:"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start server."""
|
||||
await super().start()
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._path.unlink(missing_ok=True)
|
||||
loop = asyncio.get_running_loop()
|
||||
server = self._runner.server
|
||||
assert server is not None
|
||||
self._server = await loop.create_unix_server(
|
||||
server,
|
||||
self._path,
|
||||
backlog=self._backlog,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ import jwt
|
||||
import pytest
|
||||
import yarl
|
||||
|
||||
from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.auth.providers import trusted_networks
|
||||
from homeassistant.auth.providers.homeassistant import HassAuthProvider
|
||||
@@ -32,6 +32,7 @@ from homeassistant.components.http.request_context import (
|
||||
current_request,
|
||||
setup_request_context,
|
||||
)
|
||||
from homeassistant.const import HASSIO_USER_NAME
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS
|
||||
from homeassistant.setup import async_setup_component
|
||||
@@ -658,3 +659,78 @@ async def test_create_user_once(hass: HomeAssistant) -> None:
|
||||
|
||||
# test it did not create a user
|
||||
assert len(await hass.auth.async_get_users()) == cur_users + 1
|
||||
|
||||
|
||||
async def test_unix_socket_auth_with_supervisor_user(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test that Unix socket requests are authenticated as Supervisor user."""
|
||||
supervisor_user = await hass.auth.async_create_system_user(
|
||||
HASSIO_USER_NAME, group_ids=[GROUP_ID_ADMIN]
|
||||
)
|
||||
await hass.auth.async_create_refresh_token(supervisor_user)
|
||||
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request", return_value=True
|
||||
):
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
data = await req.json()
|
||||
assert data["user_id"] == supervisor_user.id
|
||||
|
||||
|
||||
async def test_unix_socket_auth_without_supervisor_user(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test that Unix socket requests fail when no Supervisor user exists."""
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request", return_value=True
|
||||
):
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||
|
||||
|
||||
async def test_unix_socket_auth_caches_user_id(
|
||||
hass: HomeAssistant,
|
||||
app: web.Application,
|
||||
aiohttp_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test that Unix socket auth caches the Supervisor user ID."""
|
||||
supervisor_user = await hass.auth.async_create_system_user(
|
||||
HASSIO_USER_NAME, group_ids=[GROUP_ID_ADMIN]
|
||||
)
|
||||
await hass.auth.async_create_refresh_token(supervisor_user)
|
||||
|
||||
await async_setup_auth(hass, app)
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request", return_value=True
|
||||
):
|
||||
# First request triggers user lookup
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
|
||||
# Second request should use cached user ID
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.http.auth.is_unix_socket_request",
|
||||
return_value=True,
|
||||
),
|
||||
patch.object(
|
||||
hass.auth, "async_get_users", wraps=hass.auth.async_get_users
|
||||
) as mock_get_users,
|
||||
):
|
||||
req = await client.get("/")
|
||||
assert req.status == HTTPStatus.OK
|
||||
mock_get_users.assert_not_called()
|
||||
|
||||
@@ -465,3 +465,33 @@ async def test_single_ban_file_entry(
|
||||
await manager.async_add_ban(remote_ip)
|
||||
|
||||
assert m_open.call_count == 1
|
||||
|
||||
|
||||
async def test_unix_socket_skips_ban_check(
|
||||
hass: HomeAssistant, aiohttp_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
"""Test that Unix socket requests bypass ban middleware."""
|
||||
app = web.Application()
|
||||
app[KEY_HASS] = hass
|
||||
setup_bans(hass, app, 5)
|
||||
set_real_ip = mock_real_ip(app)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.load_yaml_config_file",
|
||||
return_value={
|
||||
banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS
|
||||
},
|
||||
):
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Verify the IP is actually banned for normal requests
|
||||
set_real_ip(BANNED_IPS[0])
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.FORBIDDEN
|
||||
|
||||
# Unix socket requests should bypass ban checks
|
||||
with patch(
|
||||
"homeassistant.components.http.ban.is_unix_socket_request", return_value=True
|
||||
):
|
||||
resp = await client.get("/")
|
||||
assert resp.status == HTTPStatus.NOT_FOUND
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from ipaddress import ip_network
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
@@ -735,3 +736,81 @@ async def test_server_host(
|
||||
)
|
||||
|
||||
assert set(issue_registry.issues) == expected_issues
|
||||
|
||||
|
||||
async def test_unix_socket_started_with_supervisor(
|
||||
hass: HomeAssistant,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test unix socket is started when running under Supervisor."""
|
||||
socket_path = tmp_path / "core.sock"
|
||||
mock_server = Mock()
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ, {"SUPERVISOR_CORE_API_SOCKET": str(socket_path)}, clear=False
|
||||
),
|
||||
patch("asyncio.BaseEventLoop.create_server", return_value=mock_server),
|
||||
patch(
|
||||
"asyncio.unix_events._UnixSelectorEventLoop.create_unix_server",
|
||||
return_value=mock_server,
|
||||
) as mock_create_unix,
|
||||
):
|
||||
assert await async_setup_component(hass, "http", {"http": {}})
|
||||
await hass.async_start()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mock_create_unix.assert_called_once_with(
|
||||
ANY,
|
||||
socket_path,
|
||||
backlog=128,
|
||||
)
|
||||
assert hass.http.unix_site is not None
|
||||
|
||||
|
||||
async def test_unix_socket_not_started_without_supervisor(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test unix socket is not started when not running under Supervisor."""
|
||||
mock_server = Mock()
|
||||
with (
|
||||
patch.dict(os.environ, {}, clear=False),
|
||||
patch("asyncio.BaseEventLoop.create_server", return_value=mock_server),
|
||||
patch(
|
||||
"asyncio.unix_events._UnixSelectorEventLoop.create_unix_server",
|
||||
return_value=mock_server,
|
||||
) as mock_create_unix,
|
||||
):
|
||||
os.environ.pop("SUPERVISOR_CORE_API_SOCKET", None)
|
||||
assert await async_setup_component(hass, "http", {"http": {}})
|
||||
await hass.async_start()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mock_create_unix.assert_not_called()
|
||||
assert hass.http.unix_site is None
|
||||
|
||||
|
||||
async def test_unix_socket_rejected_relative_path(
|
||||
hass: HomeAssistant,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Test unix socket is rejected when path is relative."""
|
||||
mock_server = Mock()
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{"SUPERVISOR_CORE_API_SOCKET": "relative/path.sock"},
|
||||
clear=False,
|
||||
),
|
||||
patch("asyncio.BaseEventLoop.create_server", return_value=mock_server),
|
||||
patch(
|
||||
"asyncio.unix_events._UnixSelectorEventLoop.create_unix_server",
|
||||
return_value=mock_server,
|
||||
) as mock_create_unix,
|
||||
):
|
||||
assert await async_setup_component(hass, "http", {"http": {}})
|
||||
await hass.async_start()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mock_create_unix.assert_not_called()
|
||||
assert hass.http.unix_site is None
|
||||
assert "path must be absolute" in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user