move disable to tests only

This commit is contained in:
J. Nick Koston
2024-06-12 16:44:50 -05:00
parent 219bb697da
commit 9e231df83a
5 changed files with 24 additions and 20 deletions

View File

@@ -153,12 +153,12 @@ class BlockedCalls:
calls: set[BlockingCall] calls: set[BlockingCall]
_BLOCKED_CALLS = BlockedCalls(set()) BLOCKED_CALLS = BlockedCalls(set())
def enable() -> None: def enable() -> None:
"""Enable the detection of blocking calls in the event loop.""" """Enable the detection of blocking calls in the event loop."""
if _BLOCKED_CALLS.calls: if BLOCKED_CALLS.calls:
raise RuntimeError("Blocking call detection is already enabled") raise RuntimeError("Blocking call detection is already enabled")
loop_thread_id = threading.get_ident() loop_thread_id = threading.get_ident()
@@ -174,13 +174,4 @@ def enable() -> None:
loop_thread_id=loop_thread_id, loop_thread_id=loop_thread_id,
) )
setattr(blocking_call.object, blocking_call.function, protected_function) setattr(blocking_call.object, blocking_call.function, protected_function)
_BLOCKED_CALLS.calls.add(blocking_call) BLOCKED_CALLS.calls.add(blocking_call)
def disable() -> None:
"""Disable the detection of blocking calls in the event loop."""
for blocking_call in _BLOCKED_CALLS.calls:
setattr(
blocking_call.object, blocking_call.function, blocking_call.original_func
)
_BLOCKED_CALLS.calls.clear()

View File

@@ -26,7 +26,7 @@ from syrupy import SnapshotAssertion
from typing_extensions import AsyncGenerator, Generator from typing_extensions import AsyncGenerator, Generator
import voluptuous as vol import voluptuous as vol
from homeassistant import auth, bootstrap, config_entries, loader from homeassistant import auth, block_async_io, bootstrap, config_entries, loader
from homeassistant.auth import ( from homeassistant.auth import (
auth_store, auth_store,
models as auth_models, models as auth_models,
@@ -1767,3 +1767,12 @@ async def snapshot_platform(
state = hass.states.get(entity_entry.entity_id) state = hass.states.get(entity_entry.entity_id)
assert state, f"State not found for {entity_entry.entity_id}" assert state, f"State not found for {entity_entry.entity_id}"
assert state == snapshot(name=f"{entity_entry.entity_id}-state") assert state == snapshot(name=f"{entity_entry.entity_id}-state")
def disable_block_async_io() -> None:
"""Disable the detection of blocking calls in the event loop."""
for blocking_call in block_async_io.BLOCKED_CALLS.calls:
setattr(
blocking_call.object, blocking_call.function, blocking_call.original_func
)
block_async_io.BLOCKED_CALLS.calls.clear()

View File

@@ -100,6 +100,7 @@ pytest.register_assert_rewrite("tests.common")
from .common import ( # noqa: E402, isort:skip from .common import ( # noqa: E402, isort:skip
CLIENT_ID, CLIENT_ID,
INSTANCES, INSTANCES,
disable_block_async_io,
MockConfigEntry, MockConfigEntry,
MockUser, MockUser,
async_fire_mqtt_message, async_fire_mqtt_message,
@@ -1814,3 +1815,10 @@ def service_calls(hass: HomeAssistant) -> Generator[None, None, list[ServiceCall
def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion: def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion:
"""Return snapshot assertion fixture with the Home Assistant extension.""" """Return snapshot assertion fixture with the Home Assistant extension."""
return snapshot.use_extension(HomeAssistantSnapshotExtension) return snapshot.use_extension(HomeAssistantSnapshotExtension)
@pytest.fixture
def disable_block_async_io_after() -> Generator[Any, Any, None]:
"""Disable the block async io context manager."""
yield
disable_block_async_io()

View File

@@ -18,10 +18,8 @@ from .common import extract_stack_to_frame
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def unpatch_block_async_io(): def unpatch_block_async_io(disable_block_async_io_after):
"""Unpatch block_async_io after each test.""" """Unpatch block_async_io after each test."""
yield
block_async_io.disable()
async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None: async def test_protect_loop_debugger_sleep(caplog: pytest.LogCaptureFixture) -> None:

View File

@@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from typing_extensions import Generator from typing_extensions import Generator
from homeassistant import block_async_io, bootstrap, loader, runner from homeassistant import bootstrap, loader, runner
import homeassistant.config as config_util import homeassistant.config as config_util
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DEBUG, SIGNAL_BOOTSTRAP_INTEGRATIONS from homeassistant.const import CONF_DEBUG, SIGNAL_BOOTSTRAP_INTEGRATIONS
@@ -56,10 +56,8 @@ async def apply_stop_hass(stop_hass: None) -> None:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def unpatch_block_async_io(): def unpatch_block_async_io(disable_block_async_io_after):
"""Unpatch block_async_io after each test.""" """Unpatch block_async_io after each test."""
yield
block_async_io.disable()
@pytest.fixture(scope="module", autouse=True) @pytest.fixture(scope="module", autouse=True)