we are patching objects, make it more generic

This commit is contained in:
J. Nick Koston
2024-06-12 16:35:12 -05:00
parent 4c6c45d8e0
commit 2ab5f48bd6

View File

@@ -11,7 +11,6 @@ import os
import sys import sys
import threading import threading
import time import time
from types import ModuleType
from typing import Any from typing import Any
from .helpers.frame import get_current_frame from .helpers.frame import get_current_frame
@@ -54,7 +53,7 @@ class BlockingCall:
"""Class to hold information about a blocking call.""" """Class to hold information about a blocking call."""
original_func: Callable original_func: Callable
module: ModuleType object: object
function: str function: str
check_allowed: Callable[[dict[str, Any]], bool] | None check_allowed: Callable[[dict[str, Any]], bool] | None
strict: bool strict: bool
@@ -65,7 +64,7 @@ class BlockingCall:
BLOCKING_CALLS: tuple[BlockingCall, ...] = ( BLOCKING_CALLS: tuple[BlockingCall, ...] = (
BlockingCall( BlockingCall(
original_func=HTTPConnection.putrequest, original_func=HTTPConnection.putrequest,
module=HTTPConnection, # type: ignore[arg-type] object=HTTPConnection,
function="putrequest", function="putrequest",
check_allowed=None, check_allowed=None,
strict=True, strict=True,
@@ -74,7 +73,7 @@ BLOCKING_CALLS: tuple[BlockingCall, ...] = (
), ),
BlockingCall( BlockingCall(
original_func=time.sleep, original_func=time.sleep,
module=time, object=time,
function="sleep", function="sleep",
check_allowed=_check_sleep_call_allowed, check_allowed=_check_sleep_call_allowed,
strict=True, strict=True,
@@ -83,7 +82,7 @@ BLOCKING_CALLS: tuple[BlockingCall, ...] = (
), ),
BlockingCall( BlockingCall(
original_func=glob.glob, original_func=glob.glob,
module=glob, object=glob,
function="glob", function="glob",
check_allowed=None, check_allowed=None,
strict=False, strict=False,
@@ -92,7 +91,7 @@ BLOCKING_CALLS: tuple[BlockingCall, ...] = (
), ),
BlockingCall( BlockingCall(
original_func=glob.iglob, original_func=glob.iglob,
module=glob, object=glob,
function="iglob", function="iglob",
check_allowed=None, check_allowed=None,
strict=False, strict=False,
@@ -101,7 +100,7 @@ BLOCKING_CALLS: tuple[BlockingCall, ...] = (
), ),
BlockingCall( BlockingCall(
original_func=os.walk, original_func=os.walk,
module=os, object=os,
function="walk", function="walk",
check_allowed=None, check_allowed=None,
strict=False, strict=False,
@@ -110,7 +109,7 @@ BLOCKING_CALLS: tuple[BlockingCall, ...] = (
), ),
BlockingCall( BlockingCall(
original_func=os.listdir, original_func=os.listdir,
module=os, object=os,
function="listdir", function="listdir",
check_allowed=None, check_allowed=None,
strict=False, strict=False,
@@ -119,7 +118,7 @@ BLOCKING_CALLS: tuple[BlockingCall, ...] = (
), ),
BlockingCall( BlockingCall(
original_func=os.scandir, original_func=os.scandir,
module=os, object=os,
function="scandir", function="scandir",
check_allowed=None, check_allowed=None,
strict=False, strict=False,
@@ -128,7 +127,7 @@ BLOCKING_CALLS: tuple[BlockingCall, ...] = (
), ),
BlockingCall( BlockingCall(
original_func=builtins.open, original_func=builtins.open,
module=builtins, object=builtins,
function="open", function="open",
check_allowed=_check_file_allowed, check_allowed=_check_file_allowed,
strict=False, strict=False,
@@ -137,7 +136,7 @@ BLOCKING_CALLS: tuple[BlockingCall, ...] = (
), ),
BlockingCall( BlockingCall(
original_func=importlib.import_module, original_func=importlib.import_module,
module=importlib, object=importlib,
function="import_module", function="import_module",
check_allowed=_check_import_call_allowed, check_allowed=_check_import_call_allowed,
strict=False, strict=False,
@@ -174,7 +173,7 @@ def enable() -> None:
check_allowed=blocking_call.check_allowed, check_allowed=blocking_call.check_allowed,
loop_thread_id=loop_thread_id, loop_thread_id=loop_thread_id,
) )
setattr(blocking_call.module, blocking_call.function, protected_function) setattr(blocking_call.object, blocking_call.function, protected_function)
_BLOCKED_CALLS.calls.add(blocking_call) _BLOCKED_CALLS.calls.add(blocking_call)
@@ -182,6 +181,6 @@ def disable() -> None:
"""Disable the detection of blocking calls in the event loop.""" """Disable the detection of blocking calls in the event loop."""
for blocking_call in _BLOCKED_CALLS.calls: for blocking_call in _BLOCKED_CALLS.calls:
setattr( setattr(
blocking_call.module, blocking_call.function, blocking_call.original_func blocking_call.object, blocking_call.function, blocking_call.original_func
) )
_BLOCKED_CALLS.calls.clear() _BLOCKED_CALLS.calls.clear()