Skip to content

Draft: Ensure functions are only patched with the same patch once #418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
57 changes: 55 additions & 2 deletions aikido_zen/sinks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from wrapt import wrap_object, FunctionWrapper, when_imported
import threading

from wrapt import wrap_object, FunctionWrapper, when_imported, resolve_path
from aikido_zen.background_process.packages import ANY_VERSION, is_package_compatible
from aikido_zen.errors import AikidoException
from aikido_zen.helpers.logging import logger
Expand Down Expand Up @@ -31,11 +33,58 @@
Patches a function in the specified module with a wrapper function.
"""
try:
wrap_object(module, name, FunctionWrapper, (wrapper,))
(parent, _, original) = resolve_path(module, name)

hook_duplicate_helper = HookDuplicateHelper(parent, original, wrapper)

if not hook_duplicate_helper.is_registered():
wrap_object(module, name, FunctionWrapper, (wrapper,))
hook_duplicate_helper.register()
else:
logger.debug(
"Attempted to apply same hook twice: %s", hook_duplicate_helper.hook_id
)
except Exception as e:
logger.info("Failed to wrap %s:%s, due to: %s", module, name, e)


class HookDuplicateHelper:
AIKIDO_HOOKS_LOCK = "_aikido_hooks_lock"
AIKIDO_HOOKS_STORE = "_aikido_hooks_store"

def __init__(self, parent, original, wrapper):
self.parent = parent

# Hook id is either module+name, or set via @before, ... as _hook_id
self.hook_id = f"{wrapper.__module__}:wrapper.__name__"
if hasattr(wrapper, "_hook_id"):
self.hook_id = getattr(wrapper, "_hook_id")
self.hook_id += f":{original}"

def is_registered(self):
self._try_create_hooks_store()
if not hasattr(self.parent, self.AIKIDO_HOOKS_LOCK):
return False

Check warning on line 67 in aikido_zen/sinks/__init__.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/__init__.py#L67

Added line #L67 was not covered by tests
with getattr(self.parent, self.AIKIDO_HOOKS_LOCK):
return self.hook_id in getattr(self.parent, self.AIKIDO_HOOKS_STORE)

def register(self):
self._try_create_hooks_store()
if not hasattr(self.parent, self.AIKIDO_HOOKS_LOCK):
return False

Check warning on line 74 in aikido_zen/sinks/__init__.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/__init__.py#L74

Added line #L74 was not covered by tests
with getattr(self.parent, self.AIKIDO_HOOKS_LOCK):
getattr(self.parent, self.AIKIDO_HOOKS_STORE).add(self.hook_id)

def _try_create_hooks_store(self):
if hasattr(self.parent, self.AIKIDO_HOOKS_LOCK):
return
try:
setattr(self.parent, self.AIKIDO_HOOKS_LOCK, threading.Lock())
setattr(self.parent, self.AIKIDO_HOOKS_STORE, set())
except AttributeError as e:
logger.debug("Failed to create hook storage on: %s", self.parent)

Check warning on line 85 in aikido_zen/sinks/__init__.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/__init__.py#L84-L85

Added lines #L84 - L85 were not covered by tests


def before(wrapper):
"""
Surrounds a patch with try-except and calls the original function at the end
Expand All @@ -52,6 +101,7 @@
)
return func(*args, **kwargs) # Call the original function

decorator._hook_id = f"{wrapper.__module__}:{wrapper.__name__}"
return decorator


Expand All @@ -71,6 +121,7 @@
)
return await func(*args, **kwargs) # Call the original function

decorator._hook_id = f"{wrapper.__module__}:{wrapper.__name__}"
return decorator


Expand All @@ -92,6 +143,7 @@
)
return func(*args, **kwargs) # Call the original function

decorator._hook_id = f"{wrapper.__module__}:{wrapper.__name__}"
return decorator


Expand All @@ -113,4 +165,5 @@

return return_value

decorator._hook_id = f"{wrapper.__module__}:{wrapper.__name__}"
return decorator
104 changes: 104 additions & 0 deletions aikido_zen/sinks/tests/patching_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from aikido_zen.helpers.get_argument import get_argument
from aikido_zen.sinks import patch_function, on_import, before_modify_return, after


@before_modify_return
def my_func_wrapper(func, instance, args, kwargs):
rv = func(*args, **kwargs)
return rv + 1


@before_modify_return
def my_func_wrapper_2(func, instance, args, kwargs):
rv = func(*args, **kwargs)
return rv * 3


def test_no_patch():
@on_import("aikido_zen.sinks.tests.utils.sample_module")
def patch(m):
pass # Do nothing

from aikido_zen.sinks.tests.utils.sample_module import my_func

assert my_func(1) == 2
assert my_func(2) == 3


def test_patch_happens_once():
@on_import("aikido_zen.sinks.tests.utils.sample_module")
def patch(m):
patch_function(m, "my_func", my_func_wrapper)

from aikido_zen.sinks.tests.utils.sample_module import my_func

assert my_func(1) == 3
assert my_func(2) == 4


def test_patch_happens_multiple():
@on_import("aikido_zen.sinks.tests.utils.sample_module_3")
def patch(m):
patch_function(m, "my_func", my_func_wrapper)
patch_function(m, "my_func", my_func_wrapper)
patch_function(m, "my_func", my_func_wrapper)

from aikido_zen.sinks.tests.utils.sample_module_3 import my_func

assert my_func(1) == 3
assert my_func(2) == 4


def test_patch_happens_multiple_but_different_function():
@on_import("aikido_zen.sinks.tests.utils.sample_module")
def patch(m):
patch_function(m, "my_func", my_func_wrapper)
patch_function(m, "my_func", my_func_wrapper_2)

from aikido_zen.sinks.tests.utils.sample_module import my_func

assert my_func(1) == (1 + 2) * 3 == 9
assert my_func(2) == (2 + 2) * 3 == 12


def test_patch_happens_multiple_but_different_order():
@on_import("aikido_zen.sinks.tests.utils.sample_module_4")
def patch(m):
patch_function(m, "my_func", my_func_wrapper_2)
patch_function(m, "my_func", my_func_wrapper)

from aikido_zen.sinks.tests.utils.sample_module_4 import my_func

assert my_func(1) == (2 * 3) + 1 == 7
assert my_func(2) == (3 * 3) + 1 == 10


def test_patch_happens_multiple_but_different_module():
# In this case, you will still have 2x the wrapper, because the parent is different.
@on_import("aikido_zen.sinks.tests.utils.sample_module_2")
def patch(m):
patch_function(m, "my_func", my_func_wrapper)

@on_import("aikido_zen.sinks.tests.utils.sample_module_5")
def patch2(m):
patch_function(m, "my_func", my_func_wrapper)

from aikido_zen.sinks.tests.utils.sample_module_2 import my_func

assert my_func(1) == 4
assert my_func(2) == 5


def test_patch_happens_multiple_different_module_class():
@on_import("aikido_zen.sinks.tests.utils.sample_module_2")
def patch1(m):
patch_function(m, "Functions.my_func", my_func_wrapper)

@on_import("aikido_zen.sinks.tests.utils.sample_module_5")
def patch2(m):
patch_function(m, "Functions.my_func", my_func_wrapper)

from aikido_zen.sinks.tests.utils.sample_module_2 import Functions

assert Functions.my_func(1) == 3
assert Functions.my_func(2) == 4
Empty file.
12 changes: 12 additions & 0 deletions aikido_zen/sinks/tests/utils/sample_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def my_func(a):
return a + 1


def other_func(a):
return a * 2

Check warning on line 6 in aikido_zen/sinks/tests/utils/sample_module.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/tests/utils/sample_module.py#L6

Added line #L6 was not covered by tests


class Functions:
@staticmethod
def my_func(a):
return a + 1

Check warning on line 12 in aikido_zen/sinks/tests/utils/sample_module.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/tests/utils/sample_module.py#L12

Added line #L12 was not covered by tests
1 change: 1 addition & 0 deletions aikido_zen/sinks/tests/utils/sample_module_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from aikido_zen.sinks.tests.utils.sample_module_5 import *
12 changes: 12 additions & 0 deletions aikido_zen/sinks/tests/utils/sample_module_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def my_func(a):
return a + 1


def other_func(a):
return a * 2

Check warning on line 6 in aikido_zen/sinks/tests/utils/sample_module_3.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/tests/utils/sample_module_3.py#L6

Added line #L6 was not covered by tests


class Functions:
@staticmethod
def my_func(a):
return a + 1

Check warning on line 12 in aikido_zen/sinks/tests/utils/sample_module_3.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/tests/utils/sample_module_3.py#L12

Added line #L12 was not covered by tests
12 changes: 12 additions & 0 deletions aikido_zen/sinks/tests/utils/sample_module_4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def my_func(a):
return a + 1


def other_func(a):
return a * 2

Check warning on line 6 in aikido_zen/sinks/tests/utils/sample_module_4.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/tests/utils/sample_module_4.py#L6

Added line #L6 was not covered by tests


class Functions:
@staticmethod
def my_func(a):
return a + 1

Check warning on line 12 in aikido_zen/sinks/tests/utils/sample_module_4.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/tests/utils/sample_module_4.py#L12

Added line #L12 was not covered by tests
12 changes: 12 additions & 0 deletions aikido_zen/sinks/tests/utils/sample_module_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def my_func(a):
return a + 1


def other_func(a):
return a * 2

Check warning on line 6 in aikido_zen/sinks/tests/utils/sample_module_5.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sinks/tests/utils/sample_module_5.py#L6

Added line #L6 was not covered by tests


class Functions:
@staticmethod
def my_func(a):
return a + 1
Loading