Ensure asyncio blocking checks are undone after tests run

This commit is contained in:
J. Nick Koston
2024-06-12 16:23:32 -05:00
parent 51891ff8e2
commit ed6cac4e4f
2 changed files with 164 additions and 51 deletions

View File

@@ -1,7 +1,9 @@
"""Block blocking calls being done in asyncio."""
import builtins
from collections.abc import Callable
from contextlib import suppress
from dataclasses import dataclass
import glob
from http.client import HTTPConnection
import importlib
@@ -9,6 +11,7 @@ import os
import sys
import threading
import time
from types import ModuleType
from typing import Any
from .helpers.frame import get_current_frame
@@ -46,53 +49,139 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
return False
@dataclass(slots=True, frozen=True)
class BlockingCall:
"""Class to hold information about a blocking call."""
original_func: Callable
module: ModuleType
function: str
check_allowed: Callable[[dict[str, Any]], bool] | None
strict: bool
strict_core: bool
skip_for_tests: bool
BLOCKING_CALLS: tuple[BlockingCall, ...] = (
BlockingCall(
original_func=HTTPConnection.putrequest,
module=HTTPConnection, # type: ignore[arg-type]
function="putrequest",
check_allowed=None,
strict=True,
strict_core=True,
skip_for_tests=False,
),
BlockingCall(
original_func=time.sleep,
module=time,
function="sleep",
check_allowed=_check_sleep_call_allowed,
strict=True,
strict_core=True,
skip_for_tests=False,
),
BlockingCall(
original_func=glob.glob,
module=glob,
function="glob",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=glob.iglob,
module=glob,
function="iglob",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=os.walk,
module=os,
function="walk",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=os.listdir,
module=os,
function="listdir",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=os.scandir,
module=os,
function="scandir",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=builtins.open,
module=builtins,
function="open",
check_allowed=_check_file_allowed,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=importlib.import_module,
module=importlib,
function="import_module",
check_allowed=_check_import_call_allowed,
strict=False,
strict_core=False,
skip_for_tests=True,
),
)
@dataclass(slots=True)
class BlockedCalls:
"""Class to track which calls are blocked."""
calls: set[BlockingCall]
_BLOCKED_CALLS = BlockedCalls(set())
def enable() -> None:
"""Enable the detection of blocking calls in the event loop."""
if _BLOCKED_CALLS.calls:
return
loop_thread_id = threading.get_ident()
# Prevent urllib3 and requests doing I/O in event loop
HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign]
HTTPConnection.putrequest, loop_thread_id=loop_thread_id
)
for blocking_call in BLOCKING_CALLS:
if _IN_TESTS and blocking_call.skip_for_tests:
continue
# Prevent sleeping in event loop.
time.sleep = protect_loop(
time.sleep,
check_allowed=_check_sleep_call_allowed,
loop_thread_id=loop_thread_id,
)
glob.glob = protect_loop(
glob.glob, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
glob.iglob = protect_loop(
glob.iglob, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
os.walk = protect_loop(
os.walk, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
if not _IN_TESTS:
# Prevent files being opened inside the event loop
os.listdir = protect_loop( # type: ignore[assignment]
os.listdir, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
os.scandir = protect_loop( # type: ignore[assignment]
os.scandir, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
builtins.open = protect_loop( # type: ignore[assignment]
builtins.open,
strict_core=False,
strict=False,
check_allowed=_check_file_allowed,
protected_function = protect_loop(
blocking_call.original_func,
strict=blocking_call.strict,
strict_core=blocking_call.strict_core,
check_allowed=blocking_call.check_allowed,
loop_thread_id=loop_thread_id,
)
# unittest uses `importlib.import_module` to do mocking
# so we cannot protect it if we are running tests
importlib.import_module = protect_loop(
importlib.import_module,
strict_core=False,
strict=False,
check_allowed=_check_import_call_allowed,
loop_thread_id=loop_thread_id,
setattr(blocking_call.module, blocking_call.function, protected_function)
_BLOCKED_CALLS.calls.add(blocking_call)
def disable() -> None:
"""Disable the detection of blocking calls in the event loop."""
for blocking_call in _BLOCKED_CALLS.calls:
setattr(
blocking_call.module, blocking_call.function, blocking_call.original_func
)
_BLOCKED_CALLS.calls.clear()

View File

@@ -17,6 +17,13 @@ from homeassistant.core import HomeAssistant
from .common import extract_stack_to_frame
@pytest.fixture(autouse=True)
def unpatch_block_async_io():
"""Unpatch block_async_io after each test."""
yield
block_async_io.disable()
async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None:
"""Test time.sleep injected by the debugger is not reported."""
block_async_io.enable()
@@ -214,7 +221,8 @@ async def test_protect_loop_open(caplog: pytest.LogCaptureFixture) -> None:
async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file in the event loop logs."""
block_async_io.enable()
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError):
open("/config/data_not_exist", encoding="utf8").close()
@@ -231,7 +239,8 @@ async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None:
)
async def test_protect_open_path(path: Any, caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file by path in the event loop logs."""
block_async_io.enable()
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError):
open(path, encoding="utf8").close()
@@ -242,7 +251,8 @@ async def test_protect_loop_glob(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test glob calls in the loop are logged."""
block_async_io.enable()
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
glob.glob("/dev/null")
assert "Detected blocking call to glob with args" in caplog.text
caplog.clear()
@@ -254,7 +264,8 @@ async def test_protect_loop_iglob(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test iglob calls in the loop are logged."""
block_async_io.enable()
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
glob.iglob("/dev/null")
assert "Detected blocking call to iglob with args" in caplog.text
caplog.clear()
@@ -266,7 +277,8 @@ async def test_protect_loop_scandir(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test glob calls in the loop are logged."""
block_async_io.enable()
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError):
os.scandir("/path/that/does/not/exists")
assert "Detected blocking call to scandir with args" in caplog.text
@@ -280,7 +292,8 @@ async def test_protect_loop_listdir(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test listdir calls in the loop are logged."""
block_async_io.enable()
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError):
os.listdir("/path/that/does/not/exists")
assert "Detected blocking call to listdir with args" in caplog.text
@@ -293,8 +306,9 @@ async def test_protect_loop_listdir(
async def test_protect_loop_walk(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test glob calls in the loop are logged."""
block_async_io.enable()
"""Test os.walk calls in the loop are logged."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable()
with contextlib.suppress(FileNotFoundError):
os.walk("/path/that/does/not/exists")
assert "Detected blocking call to walk with args" in caplog.text
@@ -302,3 +316,13 @@ async def test_protect_loop_walk(
with contextlib.suppress(FileNotFoundError):
await hass.async_add_executor_job(os.walk, "/path/that/does/not/exists")
assert "Detected blocking call to walk with args" not in caplog.text
async def test_open_calls_ignored_in_tests(caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file in tests is ignored."""
assert block_async_io._IN_TESTS
block_async_io.enable()
with contextlib.suppress(FileNotFoundError):
open("/config/data_not_exist", encoding="utf8").close()
assert "Detected blocking call to open with args" not in caplog.text