From 9e231df83a61626221678dab225504d158d28d52 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 12 Jun 2024 16:44:50 -0500 Subject: [PATCH] move disable to tests only --- homeassistant/block_async_io.py | 15 +++------------ tests/common.py | 11 ++++++++++- tests/conftest.py | 8 ++++++++ tests/test_block_async_io.py | 4 +--- tests/test_bootstrap.py | 6 ++---- 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/homeassistant/block_async_io.py b/homeassistant/block_async_io.py index b2b9f14bb30..49c8af56cbc 100644 --- a/homeassistant/block_async_io.py +++ b/homeassistant/block_async_io.py @@ -153,12 +153,12 @@ class BlockedCalls: calls: set[BlockingCall] -_BLOCKED_CALLS = BlockedCalls(set()) +BLOCKED_CALLS = BlockedCalls(set()) def enable() -> None: """Enable the detection of blocking calls in the event loop.""" - if _BLOCKED_CALLS.calls: + if BLOCKED_CALLS.calls: raise RuntimeError("Blocking call detection is already enabled") loop_thread_id = threading.get_ident() @@ -174,13 +174,4 @@ def enable() -> None: loop_thread_id=loop_thread_id, ) setattr(blocking_call.object, blocking_call.function, protected_function) - _BLOCKED_CALLS.calls.add(blocking_call) - - -def disable() -> None: - """Disable the detection of blocking calls in the event loop.""" - for blocking_call in _BLOCKED_CALLS.calls: - setattr( - blocking_call.object, blocking_call.function, blocking_call.original_func - ) - _BLOCKED_CALLS.calls.clear() + BLOCKED_CALLS.calls.add(blocking_call) diff --git a/tests/common.py b/tests/common.py index 5cb82cef3ba..da39168b681 100644 --- a/tests/common.py +++ b/tests/common.py @@ -26,7 +26,7 @@ from syrupy import SnapshotAssertion from typing_extensions import AsyncGenerator, Generator import voluptuous as vol -from homeassistant import auth, bootstrap, config_entries, loader +from homeassistant import auth, block_async_io, bootstrap, config_entries, loader from homeassistant.auth import ( auth_store, models as auth_models, @@ -1767,3 +1767,12 @@ async def snapshot_platform( state = hass.states.get(entity_entry.entity_id) assert state, f"State not found for {entity_entry.entity_id}" assert state == snapshot(name=f"{entity_entry.entity_id}-state") + + +def disable_block_async_io() -> None: + """Disable the detection of blocking calls in the event loop.""" + for blocking_call in block_async_io.BLOCKED_CALLS.calls: + setattr( + blocking_call.object, blocking_call.function, blocking_call.original_func + ) + block_async_io.BLOCKED_CALLS.calls.clear() diff --git a/tests/conftest.py b/tests/conftest.py index 1d0ad3d47b3..cd98acf343d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,6 +100,7 @@ pytest.register_assert_rewrite("tests.common") from .common import ( # noqa: E402, isort:skip CLIENT_ID, INSTANCES, + disable_block_async_io, MockConfigEntry, MockUser, async_fire_mqtt_message, @@ -1814,3 +1815,10 @@ def service_calls(hass: HomeAssistant) -> Generator[None, None, list[ServiceCall def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: """Return snapshot assertion fixture with the Home Assistant extension.""" return snapshot.use_extension(HomeAssistantSnapshotExtension) + + +@pytest.fixture +def disable_block_async_io_after() -> Generator[Any, Any, None]: + """Disable the block async io context manager.""" + yield + disable_block_async_io() diff --git a/tests/test_block_async_io.py b/tests/test_block_async_io.py index 4f9b5a11378..f1eba6f5ef4 100644 --- a/tests/test_block_async_io.py +++ b/tests/test_block_async_io.py @@ -18,10 +18,8 @@ from .common import extract_stack_to_frame @pytest.fixture(autouse=True) -def unpatch_block_async_io(): +def unpatch_block_async_io(disable_block_async_io_after): """Unpatch block_async_io after each test.""" - yield - block_async_io.disable() async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None: diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 63e9ebdb1a7..00d0fcf36b8 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest from typing_extensions import Generator -from homeassistant import block_async_io, bootstrap, loader, runner +from homeassistant import bootstrap, loader, runner import homeassistant.config as config_util from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_DEBUG, SIGNAL_BOOTSTRAP_INTEGRATIONS @@ -56,10 +56,8 @@ async def apply_stop_hass(stop_hass: None) -> None: @pytest.fixture(autouse=True) -def unpatch_block_async_io(): +def unpatch_block_async_io(disable_block_async_io_after): """Unpatch block_async_io after each test.""" - yield - block_async_io.disable() @pytest.fixture(scope="module", autouse=True)