Ensure asyncio blocking checks are undone after tests run

This commit is contained in:
J. Nick Koston
2024-06-12 16:23:32 -05:00
parent 51891ff8e2
commit ed6cac4e4f
2 changed files with 164 additions and 51 deletions

View File

@@ -1,7 +1,9 @@
"""Block blocking calls being done in asyncio.""" """Block blocking calls being done in asyncio."""
import builtins import builtins
from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass
import glob import glob
from http.client import HTTPConnection from http.client import HTTPConnection
import importlib import importlib
@@ -9,6 +11,7 @@ import os
import sys import sys
import threading import threading
import time import time
from types import ModuleType
from typing import Any from typing import Any
from .helpers.frame import get_current_frame from .helpers.frame import get_current_frame
@@ -46,53 +49,139 @@ def _check_sleep_call_allowed(mapped_args: dict[str, Any]) -> bool:
return False return False
@dataclass(slots=True, frozen=True)
class BlockingCall:
"""Class to hold information about a blocking call."""
original_func: Callable
module: ModuleType
function: str
check_allowed: Callable[[dict[str, Any]], bool] | None
strict: bool
strict_core: bool
skip_for_tests: bool
BLOCKING_CALLS: tuple[BlockingCall, ...] = (
BlockingCall(
original_func=HTTPConnection.putrequest,
module=HTTPConnection, # type: ignore[arg-type]
function="putrequest",
check_allowed=None,
strict=True,
strict_core=True,
skip_for_tests=False,
),
BlockingCall(
original_func=time.sleep,
module=time,
function="sleep",
check_allowed=_check_sleep_call_allowed,
strict=True,
strict_core=True,
skip_for_tests=False,
),
BlockingCall(
original_func=glob.glob,
module=glob,
function="glob",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=glob.iglob,
module=glob,
function="iglob",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=os.walk,
module=os,
function="walk",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=False,
),
BlockingCall(
original_func=os.listdir,
module=os,
function="listdir",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=os.scandir,
module=os,
function="scandir",
check_allowed=None,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=builtins.open,
module=builtins,
function="open",
check_allowed=_check_file_allowed,
strict=False,
strict_core=False,
skip_for_tests=True,
),
BlockingCall(
original_func=importlib.import_module,
module=importlib,
function="import_module",
check_allowed=_check_import_call_allowed,
strict=False,
strict_core=False,
skip_for_tests=True,
),
)
@dataclass(slots=True)
class BlockedCalls:
"""Class to track which calls are blocked."""
calls: set[BlockingCall]
_BLOCKED_CALLS = BlockedCalls(set())
def enable() -> None: def enable() -> None:
"""Enable the detection of blocking calls in the event loop.""" """Enable the detection of blocking calls in the event loop."""
if _BLOCKED_CALLS.calls:
return
loop_thread_id = threading.get_ident() loop_thread_id = threading.get_ident()
# Prevent urllib3 and requests doing I/O in event loop for blocking_call in BLOCKING_CALLS:
HTTPConnection.putrequest = protect_loop( # type: ignore[method-assign] if _IN_TESTS and blocking_call.skip_for_tests:
HTTPConnection.putrequest, loop_thread_id=loop_thread_id continue
)
# Prevent sleeping in event loop. protected_function = protect_loop(
time.sleep = protect_loop( blocking_call.original_func,
time.sleep, strict=blocking_call.strict,
check_allowed=_check_sleep_call_allowed, strict_core=blocking_call.strict_core,
check_allowed=blocking_call.check_allowed,
loop_thread_id=loop_thread_id, loop_thread_id=loop_thread_id,
) )
setattr(blocking_call.module, blocking_call.function, protected_function)
_BLOCKED_CALLS.calls.add(blocking_call)
glob.glob = protect_loop(
glob.glob, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
glob.iglob = protect_loop(
glob.iglob, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
os.walk = protect_loop(
os.walk, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
if not _IN_TESTS: def disable() -> None:
# Prevent files being opened inside the event loop """Disable the detection of blocking calls in the event loop."""
os.listdir = protect_loop( # type: ignore[assignment] for blocking_call in _BLOCKED_CALLS.calls:
os.listdir, strict_core=False, strict=False, loop_thread_id=loop_thread_id setattr(
) blocking_call.module, blocking_call.function, blocking_call.original_func
os.scandir = protect_loop( # type: ignore[assignment]
os.scandir, strict_core=False, strict=False, loop_thread_id=loop_thread_id
)
builtins.open = protect_loop( # type: ignore[assignment]
builtins.open,
strict_core=False,
strict=False,
check_allowed=_check_file_allowed,
loop_thread_id=loop_thread_id,
)
# unittest uses `importlib.import_module` to do mocking
# so we cannot protect it if we are running tests
importlib.import_module = protect_loop(
importlib.import_module,
strict_core=False,
strict=False,
check_allowed=_check_import_call_allowed,
loop_thread_id=loop_thread_id,
) )
_BLOCKED_CALLS.calls.clear()

View File

@@ -17,6 +17,13 @@ from homeassistant.core import HomeAssistant
from .common import extract_stack_to_frame from .common import extract_stack_to_frame
@pytest.fixture(autouse=True)
def unpatch_block_async_io():
"""Unpatch block_async_io after each test."""
yield
block_async_io.disable()
async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None: async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None:
"""Test time.sleep injected by the debugger is not reported.""" """Test time.sleep injected by the debugger is not reported."""
block_async_io.enable() block_async_io.enable()
@@ -214,6 +221,7 @@ async def test_protect_loop_open(caplog: pytest.LogCaptureFixture) -> None:
async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None: async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file in the event loop logs.""" """Test opening a file in the event loop logs."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable() block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
open("/config/data_not_exist", encoding="utf8").close() open("/config/data_not_exist", encoding="utf8").close()
@@ -231,6 +239,7 @@ async def test_protect_open(caplog: pytest.LogCaptureFixture) -> None:
) )
async def test_protect_open_path(path: Any, caplog: pytest.LogCaptureFixture) -> None: async def test_protect_open_path(path: Any, caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file by path in the event loop logs.""" """Test opening a file by path in the event loop logs."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable() block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
open(path, encoding="utf8").close() open(path, encoding="utf8").close()
@@ -242,6 +251,7 @@ async def test_protect_loop_glob(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test glob calls in the loop are logged.""" """Test glob calls in the loop are logged."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable() block_async_io.enable()
glob.glob("/dev/null") glob.glob("/dev/null")
assert "Detected blocking call to glob with args" in caplog.text assert "Detected blocking call to glob with args" in caplog.text
@@ -254,6 +264,7 @@ async def test_protect_loop_iglob(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test iglob calls in the loop are logged.""" """Test iglob calls in the loop are logged."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable() block_async_io.enable()
glob.iglob("/dev/null") glob.iglob("/dev/null")
assert "Detected blocking call to iglob with args" in caplog.text assert "Detected blocking call to iglob with args" in caplog.text
@@ -266,6 +277,7 @@ async def test_protect_loop_scandir(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test glob calls in the loop are logged.""" """Test glob calls in the loop are logged."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable() block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
os.scandir("/path/that/does/not/exists") os.scandir("/path/that/does/not/exists")
@@ -280,6 +292,7 @@ async def test_protect_loop_listdir(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test listdir calls in the loop are logged.""" """Test listdir calls in the loop are logged."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable() block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
os.listdir("/path/that/does/not/exists") os.listdir("/path/that/does/not/exists")
@@ -293,7 +306,8 @@ async def test_protect_loop_listdir(
async def test_protect_loop_walk( async def test_protect_loop_walk(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test glob calls in the loop are logged.""" """Test os.walk calls in the loop are logged."""
with patch.object(block_async_io, "_IN_TESTS", False):
block_async_io.enable() block_async_io.enable()
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
os.walk("/path/that/does/not/exists") os.walk("/path/that/does/not/exists")
@@ -302,3 +316,13 @@ async def test_protect_loop_walk(
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
await hass.async_add_executor_job(os.walk, "/path/that/does/not/exists") await hass.async_add_executor_job(os.walk, "/path/that/does/not/exists")
assert "Detected blocking call to walk with args" not in caplog.text assert "Detected blocking call to walk with args" not in caplog.text
async def test_open_calls_ignored_in_tests(caplog: pytest.LogCaptureFixture) -> None:
"""Test opening a file in tests is ignored."""
assert block_async_io._IN_TESTS
block_async_io.enable()
with contextlib.suppress(FileNotFoundError):
open("/config/data_not_exist", encoding="utf8").close()
assert "Detected blocking call to open with args" not in caplog.text