re-protect

This commit is contained in:
J. Nick Koston
2024-06-12 21:52:02 -05:00
parent 9e231df83a
commit dc2cb9a763
2 changed files with 9 additions and 7 deletions

View File

@@ -61,7 +61,7 @@ class BlockingCall:
skip_for_tests: bool skip_for_tests: bool
BLOCKING_CALLS: tuple[BlockingCall, ...] = ( _BLOCKING_CALLS: tuple[BlockingCall, ...] = (
BlockingCall( BlockingCall(
original_func=HTTPConnection.putrequest, original_func=HTTPConnection.putrequest,
object=HTTPConnection, object=HTTPConnection,
@@ -153,16 +153,17 @@ class BlockedCalls:
calls: set[BlockingCall] calls: set[BlockingCall]
BLOCKED_CALLS = BlockedCalls(set()) _BLOCKED_CALLS = BlockedCalls(set())
def enable() -> None: def enable() -> None:
"""Enable the detection of blocking calls in the event loop.""" """Enable the detection of blocking calls in the event loop."""
if BLOCKED_CALLS.calls: calls = _BLOCKED_CALLS.calls
if calls:
raise RuntimeError("Blocking call detection is already enabled") raise RuntimeError("Blocking call detection is already enabled")
loop_thread_id = threading.get_ident() loop_thread_id = threading.get_ident()
for blocking_call in BLOCKING_CALLS: for blocking_call in _BLOCKING_CALLS:
if _IN_TESTS and blocking_call.skip_for_tests: if _IN_TESTS and blocking_call.skip_for_tests:
continue continue
@@ -174,4 +175,4 @@ def enable() -> None:
loop_thread_id=loop_thread_id, loop_thread_id=loop_thread_id,
) )
setattr(blocking_call.object, blocking_call.function, protected_function) setattr(blocking_call.object, blocking_call.function, protected_function)
BLOCKED_CALLS.calls.add(blocking_call) calls.add(blocking_call)

View File

@@ -1771,8 +1771,9 @@ async def snapshot_platform(
def disable_block_async_io() -> None: def disable_block_async_io() -> None:
"""Disable the detection of blocking calls in the event loop.""" """Disable the detection of blocking calls in the event loop."""
for blocking_call in block_async_io.BLOCKED_CALLS.calls: calls = block_async_io._BLOCKED_CALLS.calls
for blocking_call in calls:
setattr( setattr(
blocking_call.object, blocking_call.function, blocking_call.original_func blocking_call.object, blocking_call.function, blocking_call.original_func
) )
block_async_io.BLOCKED_CALLS.calls.clear() calls.clear()