diff --git a/aikido_zen/sinks/__init__.py b/aikido_zen/sinks/__init__.py index 5d3278b6..5fcfd897 100644 --- a/aikido_zen/sinks/__init__.py +++ b/aikido_zen/sinks/__init__.py @@ -1,4 +1,6 @@ -from wrapt import wrap_object, FunctionWrapper, when_imported +import threading + +from wrapt import wrap_object, FunctionWrapper, when_imported, resolve_path from aikido_zen.background_process.packages import ANY_VERSION, is_package_compatible from aikido_zen.errors import AikidoException from aikido_zen.helpers.logging import logger @@ -31,11 +33,58 @@ def patch_function(module, name, wrapper): Patches a function in the specified module with a wrapper function. """ try: - wrap_object(module, name, FunctionWrapper, (wrapper,)) + (parent, _, original) = resolve_path(module, name) + + hook_duplicate_helper = HookDuplicateHelper(parent, original, wrapper) + + if not hook_duplicate_helper.is_registered(): + wrap_object(module, name, FunctionWrapper, (wrapper,)) + hook_duplicate_helper.register() + else: + logger.debug( + "Attempted to apply same hook twice: %s", hook_duplicate_helper.hook_id + ) except Exception as e: logger.info("Failed to wrap %s:%s, due to: %s", module, name, e) +class HookDuplicateHelper: + AIKIDO_HOOKS_LOCK = "_aikido_hooks_lock" + AIKIDO_HOOKS_STORE = "_aikido_hooks_store" + + def __init__(self, parent, original, wrapper): + self.parent = parent + + # Hook id is either module+name, or set via @before, ... as _hook_id + self.hook_id = f"{wrapper.__module__}:wrapper.__name__" + if hasattr(wrapper, "_hook_id"): + self.hook_id = getattr(wrapper, "_hook_id") + self.hook_id += f":{original}" + + def is_registered(self): + self._try_create_hooks_store() + if not hasattr(self.parent, self.AIKIDO_HOOKS_LOCK): + return False + with getattr(self.parent, self.AIKIDO_HOOKS_LOCK): + return self.hook_id in getattr(self.parent, self.AIKIDO_HOOKS_STORE) + + def register(self): + self._try_create_hooks_store() + if not hasattr(self.parent, self.AIKIDO_HOOKS_LOCK): + return False + with getattr(self.parent, self.AIKIDO_HOOKS_LOCK): + getattr(self.parent, self.AIKIDO_HOOKS_STORE).add(self.hook_id) + + def _try_create_hooks_store(self): + if hasattr(self.parent, self.AIKIDO_HOOKS_LOCK): + return + try: + setattr(self.parent, self.AIKIDO_HOOKS_LOCK, threading.Lock()) + setattr(self.parent, self.AIKIDO_HOOKS_STORE, set()) + except AttributeError as e: + logger.debug("Failed to create hook storage on: %s", self.parent) + + def before(wrapper): """ Surrounds a patch with try-except and calls the original function at the end @@ -52,6 +101,7 @@ def decorator(func, instance, args, kwargs): ) return func(*args, **kwargs) # Call the original function + decorator._hook_id = f"{wrapper.__module__}:{wrapper.__name__}" return decorator @@ -71,6 +121,7 @@ async def decorator(func, instance, args, kwargs): ) return await func(*args, **kwargs) # Call the original function + decorator._hook_id = f"{wrapper.__module__}:{wrapper.__name__}" return decorator @@ -92,6 +143,7 @@ def decorator(func, instance, args, kwargs): ) return func(*args, **kwargs) # Call the original function + decorator._hook_id = f"{wrapper.__module__}:{wrapper.__name__}" return decorator @@ -113,4 +165,5 @@ def decorator(func, instance, args, kwargs): return return_value + decorator._hook_id = f"{wrapper.__module__}:{wrapper.__name__}" return decorator diff --git a/aikido_zen/sinks/tests/patching_test.py b/aikido_zen/sinks/tests/patching_test.py new file mode 100644 index 00000000..d04a8bb3 --- /dev/null +++ b/aikido_zen/sinks/tests/patching_test.py @@ -0,0 +1,104 @@ +from aikido_zen.helpers.get_argument import get_argument +from aikido_zen.sinks import patch_function, on_import, before_modify_return, after + + +@before_modify_return +def my_func_wrapper(func, instance, args, kwargs): + rv = func(*args, **kwargs) + return rv + 1 + + +@before_modify_return +def my_func_wrapper_2(func, instance, args, kwargs): + rv = func(*args, **kwargs) + return rv * 3 + + +def test_no_patch(): + @on_import("aikido_zen.sinks.tests.utils.sample_module") + def patch(m): + pass # Do nothing + + from aikido_zen.sinks.tests.utils.sample_module import my_func + + assert my_func(1) == 2 + assert my_func(2) == 3 + + +def test_patch_happens_once(): + @on_import("aikido_zen.sinks.tests.utils.sample_module") + def patch(m): + patch_function(m, "my_func", my_func_wrapper) + + from aikido_zen.sinks.tests.utils.sample_module import my_func + + assert my_func(1) == 3 + assert my_func(2) == 4 + + +def test_patch_happens_multiple(): + @on_import("aikido_zen.sinks.tests.utils.sample_module_3") + def patch(m): + patch_function(m, "my_func", my_func_wrapper) + patch_function(m, "my_func", my_func_wrapper) + patch_function(m, "my_func", my_func_wrapper) + + from aikido_zen.sinks.tests.utils.sample_module_3 import my_func + + assert my_func(1) == 3 + assert my_func(2) == 4 + + +def test_patch_happens_multiple_but_different_function(): + @on_import("aikido_zen.sinks.tests.utils.sample_module") + def patch(m): + patch_function(m, "my_func", my_func_wrapper) + patch_function(m, "my_func", my_func_wrapper_2) + + from aikido_zen.sinks.tests.utils.sample_module import my_func + + assert my_func(1) == (1 + 2) * 3 == 9 + assert my_func(2) == (2 + 2) * 3 == 12 + + +def test_patch_happens_multiple_but_different_order(): + @on_import("aikido_zen.sinks.tests.utils.sample_module_4") + def patch(m): + patch_function(m, "my_func", my_func_wrapper_2) + patch_function(m, "my_func", my_func_wrapper) + + from aikido_zen.sinks.tests.utils.sample_module_4 import my_func + + assert my_func(1) == (2 * 3) + 1 == 7 + assert my_func(2) == (3 * 3) + 1 == 10 + + +def test_patch_happens_multiple_but_different_module(): + # In this case, you will still have 2x the wrapper, because the parent is different. + @on_import("aikido_zen.sinks.tests.utils.sample_module_2") + def patch(m): + patch_function(m, "my_func", my_func_wrapper) + + @on_import("aikido_zen.sinks.tests.utils.sample_module_5") + def patch2(m): + patch_function(m, "my_func", my_func_wrapper) + + from aikido_zen.sinks.tests.utils.sample_module_2 import my_func + + assert my_func(1) == 4 + assert my_func(2) == 5 + + +def test_patch_happens_multiple_different_module_class(): + @on_import("aikido_zen.sinks.tests.utils.sample_module_2") + def patch1(m): + patch_function(m, "Functions.my_func", my_func_wrapper) + + @on_import("aikido_zen.sinks.tests.utils.sample_module_5") + def patch2(m): + patch_function(m, "Functions.my_func", my_func_wrapper) + + from aikido_zen.sinks.tests.utils.sample_module_2 import Functions + + assert Functions.my_func(1) == 3 + assert Functions.my_func(2) == 4 diff --git a/aikido_zen/sinks/tests/utils/__init__.py b/aikido_zen/sinks/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aikido_zen/sinks/tests/utils/sample_module.py b/aikido_zen/sinks/tests/utils/sample_module.py new file mode 100644 index 00000000..51d923f1 --- /dev/null +++ b/aikido_zen/sinks/tests/utils/sample_module.py @@ -0,0 +1,12 @@ +def my_func(a): + return a + 1 + + +def other_func(a): + return a * 2 + + +class Functions: + @staticmethod + def my_func(a): + return a + 1 diff --git a/aikido_zen/sinks/tests/utils/sample_module_2.py b/aikido_zen/sinks/tests/utils/sample_module_2.py new file mode 100644 index 00000000..4b487e0b --- /dev/null +++ b/aikido_zen/sinks/tests/utils/sample_module_2.py @@ -0,0 +1 @@ +from aikido_zen.sinks.tests.utils.sample_module_5 import * diff --git a/aikido_zen/sinks/tests/utils/sample_module_3.py b/aikido_zen/sinks/tests/utils/sample_module_3.py new file mode 100644 index 00000000..51d923f1 --- /dev/null +++ b/aikido_zen/sinks/tests/utils/sample_module_3.py @@ -0,0 +1,12 @@ +def my_func(a): + return a + 1 + + +def other_func(a): + return a * 2 + + +class Functions: + @staticmethod + def my_func(a): + return a + 1 diff --git a/aikido_zen/sinks/tests/utils/sample_module_4.py b/aikido_zen/sinks/tests/utils/sample_module_4.py new file mode 100644 index 00000000..51d923f1 --- /dev/null +++ b/aikido_zen/sinks/tests/utils/sample_module_4.py @@ -0,0 +1,12 @@ +def my_func(a): + return a + 1 + + +def other_func(a): + return a * 2 + + +class Functions: + @staticmethod + def my_func(a): + return a + 1 diff --git a/aikido_zen/sinks/tests/utils/sample_module_5.py b/aikido_zen/sinks/tests/utils/sample_module_5.py new file mode 100644 index 00000000..51d923f1 --- /dev/null +++ b/aikido_zen/sinks/tests/utils/sample_module_5.py @@ -0,0 +1,12 @@ +def my_func(a): + return a + 1 + + +def other_func(a): + return a * 2 + + +class Functions: + @staticmethod + def my_func(a): + return a + 1