From 2c8125126987f6c64a54e8d6547bc3dffb4f5682 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 12:48:45 +0200 Subject: [PATCH 01/54] u Update DTAT399F_backend._config.ipynb Update DTAT399F_backend._config.ipynb Update _config.py Update __init__.py Update __init__.py Update test__config.py Update core.py Update core.py Update core.py Update test_core.py Update core.py --- deeptrack/__init__.py | 4 +- deeptrack/backend/__init__.py | 22 +++-- deeptrack/backend/_config.py | 32 ++++---- deeptrack/backend/core.py | 80 +++++++++++-------- deeptrack/tests/backend/__init__.py | 4 +- deeptrack/tests/backend/test__config.py | 14 ++++ deeptrack/tests/backend/test_core.py | 26 +++--- .../DTAT399F_backend._config.ipynb | 8 +- 8 files changed, 105 insertions(+), 85 deletions(-) diff --git a/deeptrack/__init__.py b/deeptrack/__init__.py index 53da83c9d..0c32be3b2 100644 --- a/deeptrack/__init__.py +++ b/deeptrack/__init__.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from pint import UnitRegistry -from .backend.pint_definition import pint_definitions +from deeptrack.backend.pint_definition import pint_definitions import warnings import importlib.util @@ -26,6 +26,8 @@ # Create a unit registry with custom pixel-related units. units = UnitRegistry(pint_definitions.split("\n")) +from deeptrack.backend import * + from deeptrack.features import * from deeptrack.aberrations import * from deeptrack.augmentations import * diff --git a/deeptrack/backend/__init__.py b/deeptrack/backend/__init__.py index 4ea44c9dd..c790ffb5f 100644 --- a/deeptrack/backend/__init__.py +++ b/deeptrack/backend/__init__.py @@ -1,16 +1,12 @@ -from deeptrack.backend._config import ( - config, - OPENCV_AVAILABLE, - TORCH_AVAILABLE, - xp, -) -from deeptrack.backend import core - +from deeptrack.backend._config import * +from deeptrack.backend.core import * __all__ = [ - "config", - "core", - "OPENCV_AVAILABLE", - "TORCH_AVAILABLE", - "xp", + "config", # deeptrack.backend._config + "OPENCV_AVAILABLE", # deeptrack.backend._config + "TORCH_AVAILABLE", # deeptrack.backend._config + "xp", # deeptrack.backend._config + "DeepTrackDataDict", # deeptrack.backend.core + "DeepTrackDataObject", # deeptrack.backend.core + "DeepTrackNode", # deeptrack.backend.core ] diff --git a/deeptrack/backend/_config.py b/deeptrack/backend/_config.py index 431dfb5aa..d789389e1 100644 --- a/deeptrack/backend/_config.py +++ b/deeptrack/backend/_config.py @@ -65,7 +65,7 @@ Import the global config object and the xp proxy for backend-agnostic code: ->>> from deeptrack.backend._config import config, xp +>>> from deeptrack.backend import config, xp Check the default backend and device: @@ -107,13 +107,13 @@ Check PyTorch availability: ->>> from deeptrack.backend._config import TORCH_AVAILABLE +>>> from deeptrack.backend import TORCH_AVAILABLE >>> print(TORCH_AVAILABLE) Check OpenCV availability: ->>> from deeptrack.backend._config import OPENCV_AVAILABLE +>>> from deeptrack.backend import OPENCV_AVAILABLE >>> print(OPENCV_AVAILABLE) @@ -646,7 +646,7 @@ class Config: Create the singleton configuration object and check its defaults: - >>> from deeptrack.backend._config import config + >>> from deeptrack.backend import config >>> print(config.get_backend()) # Output: 'numpy' >>> print(config.get_device()) # Output: 'cpu' @@ -660,7 +660,7 @@ class Config: Use the xp proxy to create arrays/tensors: - >>> from deeptrack.backend._config import xp + >>> from deeptrack.backend import xp >>> config.set_backend_numpy() >>> array = xp.arange(5) @@ -726,7 +726,7 @@ def set_device( Import the singleton configuration object: - >>> from deeptrack.backend._config import config + >>> from deeptrack.backend import config Set device to CPU (works with both NumPy and PyTorch backends): @@ -778,7 +778,7 @@ def get_device(self: Config) -> str | torch.device: -------- Import the singleton configuration object: - >>> from deeptrack.backend._config import config + >>> from deeptrack.backend import config Get the current device: @@ -795,7 +795,7 @@ def set_backend_numpy(self: Config) -> None: -------- Import the singleton configuration object: - >>> from deeptrack.backend._config import config + >>> from deeptrack.backend import config Set the backend to NumPy: @@ -804,7 +804,7 @@ def set_backend_numpy(self: Config) -> None: NumPy backend enables use of standard NumPy arrays via the xp proxy: - >>> from deeptrack.backend._config import xp + >>> from deeptrack.backend import xp >>> array = xp.arange(5) >>> print(type(array)) # Output: @@ -819,7 +819,7 @@ def set_backend_torch(self: Config) -> None: -------- Import the singleton configuration object: - >>> from deeptrack.backend._config import config + >>> from deeptrack.backend import config Set the backend to PyTorch: @@ -828,7 +828,7 @@ def set_backend_torch(self: Config) -> None: PyTorch backend enables use of PyTorch tensors via the xp proxy: - >>> from deeptrack.backend._config import xp + >>> from deeptrack.backend import xp >>> tensor = xp.arange(5) >>> print(type(tensor)) # Output: @@ -852,7 +852,7 @@ def set_backend( -------- Import the singleton configuration object: - >>> from deeptrack.backend._config import config + >>> from deeptrack.backend import config Set the backend to NumPy: @@ -866,7 +866,7 @@ def set_backend( Switch between backends as needed in your workflow using the xp proxy: - >>> from deeptrack.backend._config import xp + >>> from deeptrack.backend import xp >>> config.set_backend("numpy") >>> array = xp.arange(4) @@ -899,7 +899,7 @@ def get_backend(self: Config) -> Literal["numpy", "torch"]: -------- Import the singleton configuration object: - >>> from deeptrack.backend._config import config + >>> from deeptrack.backend import config Get the current backend: @@ -932,7 +932,7 @@ def with_backend( -------- Import the singleton configuration object: - >>> from deeptrack.backend._config import config + >>> from deeptrack.backend import config Temporarily switch to the NumPy backend for a block of code: @@ -946,7 +946,7 @@ def with_backend( Temporarily switch to the PyTorch backend inside a function: - >>> from deeptrack.backend._config import xp + >>> from deeptrack.backend import xp >>> config.set_backend("numpy") diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index dfd656ebf..9b71340af 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -67,18 +67,24 @@ """ +from __future__ import annotations + import operator # Operator overloading for computation nodes. from weakref import WeakSet # Manages relationships between nodes without # creating circular dependencies. +from typing import Any, Callable, Iterator + +from deeptrack.utils import get_kwarg_names -from typing import ( - Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union -) -from .. import utils +__all__ = [ + "DeepTrackDataDict", + "DeepTrackDataObject", + "DeepTrackNode", +] -citation_midtvet2021quantitative = """ +CITATION_MIDTVEDT2021QUANTITATIVE = """ @article{Midtvet2021Quantitative, author = {Midtvedt, Benjamin and Helgadottir, Saga and Argun, Aykut and Pineda, Jesús and Midtvedt, Daniel and Volpe, Giovanni}, @@ -101,29 +107,31 @@ class DeepTrackDataObject: Attributes ---------- - data : Any + data: Any The stored data. Default is `None`. - valid : bool + valid: bool A flag indicating whether the stored data is valid. Default is `False`. Methods ------- - store(data : Any) -> None - Stores data in the container and marks it as valid. - current_value() -> Any - Returns the currently stored data. - is_valid() -> bool - Returns whether the stored data is valid. - invalidate() -> None - Marks the data as invalid. - validate() -> None - Marks the data as valid. + `store(data: Any) -> None` + Store data in the container and mark it as valid. + `current_value() -> Any` + Return the currently stored data. + `is_valid() -> bool` + Return whether the stored data is valid. + `invalidate() -> None` + Mark the data as invalid. + `validate() -> None` + Mark the data as valid. Example ------- + >>> import deeptrack as dt + Create a `DeepTrackDataObject`: - >>> data_obj = core.DeepTrackDataObject() + >>> data_obj = dt.DeepTrackDataObject() Store a value in this container: @@ -142,7 +150,7 @@ class DeepTrackDataObject: >>> print(data_obj.is_valid()) False - Validate the data again to restore its status: + Validate the data again to restore its valid status: >>> data_obj.validate() >>> print(data_obj.is_valid()) @@ -150,47 +158,49 @@ class DeepTrackDataObject: """ - # Attributes. data: Any valid: bool - def __init__(self): + def __init__(self: DeepTrackDataObject): """Initialize the container without data. - The `data` and `valid` attributes are set to their default values - `None` and `False`. - + It sets the `data` and `valid` attributes are set to their default + values `None` and `False`. + """ self.data = None self.valid = False - def store(self, data: Any) -> None: + def store( + self: DeepTrackDataObject, + data: Any, + ) -> None: """Store data and mark it as valid. Parameters ---------- - data : Any + data: Any The data to be stored in the container. - + """ self.data = data self.valid = True - def current_value(self) -> Any: + def current_value(self: DeepTrackDataObject) -> Any: """Retrieve the stored data. Returns ------- Any The data stored in the container. - + """ return self.data - def is_valid(self) -> bool: + def is_valid(self: DeepTrackDataObject) -> bool: """Return whether the stored data is valid. Returns @@ -202,12 +212,12 @@ def is_valid(self) -> bool: return self.valid - def invalidate(self) -> None: + def invalidate(self: DeepTrackDataObject) -> None: """Mark the stored data as invalid.""" self.valid = False - def validate(self) -> None: + def validate(self: DeepTrackDataObject) -> None: """Mark the stored data as valid.""" self.valid = True @@ -625,7 +635,7 @@ class DeepTrackNode: _all_children: Set['DeepTrackNode'] # Citations associated with DeepTrack2. - _citations: List[str] = [citation_midtvet2021quantitative] + _citations: List[str] = [CITATION_MIDTVEDT2021QUANTITATIVE] @property def action(self) -> Callable[..., Any]: @@ -653,7 +663,7 @@ def action(self, value: Callable[..., Any]) -> None: """ self._action = value - self._accepts_ID = "_ID" in utils.get_kwarg_names(value) + self._accepts_ID = "_ID" in get_kwarg_names(value) def __init__( self, @@ -688,7 +698,7 @@ def __init__( self.action = lambda: action # Check if action accepts `_ID`. - self._accepts_ID = "_ID" in utils.get_kwarg_names(self.action) + self._accepts_ID = "_ID" in get_kwarg_names(self.action) # Call super init in case of multiple inheritance. super().__init__(**kwargs) diff --git a/deeptrack/tests/backend/__init__.py b/deeptrack/tests/backend/__init__.py index ee4a54a96..8b1378917 100644 --- a/deeptrack/tests/backend/__init__.py +++ b/deeptrack/tests/backend/__init__.py @@ -1,3 +1 @@ -#from .test_core import * -#from .test_mie import * -#from .test_polynomials import * + diff --git a/deeptrack/tests/backend/test__config.py b/deeptrack/tests/backend/test__config.py index 95eabdaa0..0d89fd065 100644 --- a/deeptrack/tests/backend/test__config.py +++ b/deeptrack/tests/backend/test__config.py @@ -23,6 +23,20 @@ def tearDown(self): _config.config.set_backend(self.original_backend) _config.config.set_device(self.original_device) + def test___all__(self): + from deeptrack import ( + config, + OPENCV_AVAILABLE, + TORCH_AVAILABLE, + xp, + ) + from deeptrack.backend import ( + config, + OPENCV_AVAILABLE, + TORCH_AVAILABLE, + xp, + ) + def test_TORCH_AVAILABLE(self): try: import torch diff --git a/deeptrack/tests/backend/test_core.py b/deeptrack/tests/backend/test_core.py index 8ba66a779..c8eb6d432 100644 --- a/deeptrack/tests/backend/test_core.py +++ b/deeptrack/tests/backend/test_core.py @@ -13,6 +13,18 @@ class TestCore(unittest.TestCase): + def test___all__(self): + from deeptrack import ( + DeepTrackDataDict, + DeepTrackDataObject, + DeepTrackNode, + ) + from deeptrack.backend import ( + DeepTrackDataDict, + DeepTrackDataObject, + DeepTrackNode, + ) + def test_DeepTrackDataObject(self): dataobj = core.DeepTrackDataObject() @@ -31,7 +43,6 @@ def test_DeepTrackDataObject(self): self.assertEqual(dataobj.current_value(), 1) self.assertEqual(dataobj.is_valid(), True) - def test_DeepTrackDataDict(self): dataset = core.DeepTrackDataDict() @@ -93,7 +104,6 @@ def test_DeepTrackDataDict(self): self.assertIn(key, {(0,), (1,)}) self.assertIsInstance(value, core.DeepTrackDataObject) - def test_DeepTrackNode_basics(self): node = core.DeepTrackNode(action=lambda: 42) @@ -116,7 +126,6 @@ def test_DeepTrackNode_basics(self): self.assertEqual(node(), 42) # Value is calculated and stored. self.assertTrue(node.is_valid()) - def test_DeepTrackNode_dependencies(self): parent = core.DeepTrackNode(action=lambda: 10) child = core.DeepTrackNode(action=lambda _ID=None: parent() * 2) @@ -147,7 +156,6 @@ def test_DeepTrackNode_dependencies(self): self.assertTrue(parent.is_valid()) self.assertTrue(child.is_valid()) - def test_DeepTrackNode_nested_dependencies(self): parent = core.DeepTrackNode(action=lambda: 5) middle = core.DeepTrackNode(action=lambda: parent() + 5) @@ -165,7 +173,6 @@ def test_DeepTrackNode_nested_dependencies(self): self.assertFalse(middle.is_valid()) self.assertFalse(child.is_valid()) - def test_DeepTrackNode_op_overloading(self): node1 = core.DeepTrackNode(action=lambda: 5) node2 = core.DeepTrackNode(action=lambda: 10) @@ -182,12 +189,10 @@ def test_DeepTrackNode_op_overloading(self): div_node = node2 / node1 self.assertEqual(div_node(), 2) - def test_DeepTrackNode_citations(self): node = core.DeepTrackNode(action=lambda: 42) citations = node.get_citations() - self.assertIn(core.citation_midtvet2021quantitative, citations) - + self.assertIn(core.CITATION_MIDTVEDT2021QUANTITATIVE, citations) def test_DeepTrackNode_single_id(self): # Test a single _ID on a simple parent-child relationship. @@ -205,7 +210,6 @@ def test_DeepTrackNode_single_id(self): self.assertEqual(child(_ID=(id,)), value * 2) self.assertEqual(parent.previous((id,)), value) - def test_DeepTrackNode_nested_ids(self): # Test nested IDs for parent-child relationships. @@ -232,7 +236,6 @@ def test_DeepTrackNode_nested_ids(self): child_value_1_1 = child(_ID=(1, 1)) # Uses parent(_ID=(1,)). self.assertEqual(child_value_1_1, 10) - def test_DeepTrackNode_replicated_behavior(self): # Test replicated behavior where IDs expand. @@ -246,7 +249,6 @@ def test_DeepTrackNode_replicated_behavior(self): cluster_value = cluster() self.assertEqual(cluster_value, 3) - def test_DeepTrackNode_parent_id_inheritance(self): # Children with IDs matching than parents. @@ -280,7 +282,6 @@ def test_DeepTrackNode_parent_id_inheritance(self): self.assertEqual(child_deeper(_ID=(1, 1)), 10) self.assertEqual(child_deeper(_ID=(1, 2)), 10) - def test_DeepTrackNode_invalidation_and_ids(self): # Test that invalidating a parent affects specific IDs of children. @@ -306,7 +307,6 @@ def test_DeepTrackNode_invalidation_and_ids(self): self.assertFalse(child.is_valid((1, 0))) self.assertFalse(child.is_valid((1, 1))) - def test_DeepTrackNode_dependency_graph_with_ids(self): # Test a multi-level dependency graph with nested IDs. diff --git a/tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb b/tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb index a7b33c7e2..4dea5ead9 100644 --- a/tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb +++ b/tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb @@ -78,7 +78,7 @@ "metadata": {}, "outputs": [], "source": [ - "from deeptrack.backend._config import config" + "from deeptrack.backend import config" ] }, { @@ -375,7 +375,7 @@ "metadata": {}, "outputs": [], "source": [ - "from deeptrack.backend._config import xp" + "from deeptrack.backend import xp" ] }, { @@ -462,7 +462,7 @@ } ], "source": [ - "from deeptrack.backend._config import TORCH_AVAILABLE\n", + "from deeptrack.backend import TORCH_AVAILABLE\n", "\n", "print(TORCH_AVAILABLE)" ] @@ -488,7 +488,7 @@ } ], "source": [ - "from deeptrack.backend._config import OPENCV_AVAILABLE\n", + "from deeptrack.backend import OPENCV_AVAILABLE\n", "\n", "print(OPENCV_AVAILABLE)" ] From abdfc82d5e1772c6c38216b9b34f1637d646cdc6 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 13:35:54 +0200 Subject: [PATCH 02/54] Update _config.py --- deeptrack/backend/_config.py | 245 +++++++++++++++++++++++++---------- 1 file changed, 173 insertions(+), 72 deletions(-) diff --git a/deeptrack/backend/_config.py b/deeptrack/backend/_config.py index d789389e1..f0df8c3d0 100644 --- a/deeptrack/backend/_config.py +++ b/deeptrack/backend/_config.py @@ -69,41 +69,53 @@ Check the default backend and device: ->>> print(config.get_backend()) # Output: 'numpy' ->>> print(config.get_device()) # Output: 'cpu' +>>> print(config.get_backend()) +'numpy' + +>>> print(config.get_device()) +'cpu' Use the xp proxy to create a NumPy array: >>> array = xp.arange(5) ->>> print(type(array)) # Output: +>>> print(type(array)) + Switch to the PyTorch backend and use GPU: >>> config.set_backend_torch() +>>> print(config.get_backend()) +'torch' + >>> config.set_device("cuda") ->>> print(config.get_backend()) # Output: 'torch' ->>> print(config.get_device()) # Output: 'cuda' +>>> print(config.get_device()) +'cuda' Create a tensor using the xp proxy: >>> tensor = xp.arange(3) ->>> print(type(tensor)) # Output: +>>> print(type(tensor)) + Temporarily switch backends within a context manager: ->>> print(config.get_backend()) # Output: 'torch' +>>> print(config.get_backend()) +'torch' >>> with config.with_backend("numpy"): -... print(config.get_backend()) # Output: 'numpy' +... print(config.get_backend()) +'numpy' ->>> print(config.get_backend()) # Output: 'torch' +>>> print(config.get_backend()) +'torch' Use PyTorch-specific device objects if desired: >>> import torch >>> config.set_device(torch.device("cuda:0")) ->>> print(config.get_device()) # Output: device(type='cuda', index=0) +>>> print(config.get_device()) +device(type='cuda', index=0) Check PyTorch availability: @@ -208,14 +220,18 @@ class _Proxy(types.ModuleType): Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np + >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) Use the proxy to create an array (calls NumPy under the hood): >>> array = xp.arange(5) - >>> print(array) # Output: [0 1 2 3 4] - >>> print(type(array)) # Output: + >>> print(array) + Output: [0 1 2 3 4] + + >>> print(type(array)) + You can use any function or attribute provided by the backend: @@ -223,33 +239,54 @@ class _Proxy(types.ModuleType): Query dtypes in a backend-agnostic way: - >>> print(xp.get_float_dtype()) # Output: float64 - >>> print(xp.get_int_dtype()) # Output: int64 - >>> print(xp.get_complex_dtype()) # Output: complex128 - >>> print(xp.get_bool_dtype()) # Output: bool + >>> print(xp.get_float_dtype()) + float64 + + >>> print(xp.get_int_dtype()) + int64 + + >>> print(xp.get_complex_dtype()) + complex128 + + >>> print(xp.get_bool_dtype()) + bool Switch to the PyTorch backend: >>> from array_api_compat import torch as apc_torch + >>> >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) Now the proxy uses PyTorch: + >>> array = xp.arange(5) - >>> print(array) # Output: tensor([0, 1, 2, 3, 4]) - >>> print(type(array)) # Output: + >>> print(array) + tensor([0, 1, 2, 3, 4]) + + >>> print(type(array)) + The dtype helpers return PyTorch-specific types: - >>> print(xp.get_float_dtype()) # Output: torch.float32 - >>> print(xp.get_int_dtype()) # Output: torch.int64 - >>> print(xp.get_complex_dtype()) # Output: torch.complex64 - >>> print(xp.get_bool_dtype()) # Output: torch.bool + >>> print(xp.get_float_dtype()) + torch.float32 + + >>> print(xp.get_int_dtype()) + torch.int64 + + >>> print(xp.get_complex_dtype()) + torch.complex64 + + >>> print(xp.get_bool_dtype()) + torch.bool You can switch backends as often as needed.: + >>> xp.set_backend(apc_np) >>> array = xp.arange(3) - >>> print(type(array)) # Output: + >>> print(type(array)) + """ @@ -292,7 +329,8 @@ def set_backend( >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) >>> array = xp.arange(5) - >>> print(type(array)) # Output: + >>> print(type(array)) + Now switch to a PyTorch backend: @@ -301,7 +339,8 @@ def set_backend( >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) >>> array = xp.arange(5) - >>> print(type(array)) # Output: + >>> print(type(array)) + """ @@ -333,26 +372,32 @@ def get_float_dtype( Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np + >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) >>> dtype = xp.get_float_dtype() - >>> print(dtype) # Output: float64 + >>> print(dtype) + float64 >>> dtype = xp.get_float_dtype("float32") - >>> print(dtype) # Output: float32 + >>> print(dtype) + float32 Now switch to a PyTorch backend: >>> from array_api_compat import torch as apc_torch + >>> >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) >>> dtype = xp.get_float_dtype() - >>> print(dtype) # Output: torch.float32 + >>> print(dtype) + torch.float32 >>> dtype = xp.get_float_dtype("float32") - >>> print(dtype) # Output: torch.float32 + >>> print(dtype) + torch.float32 """ @@ -385,14 +430,17 @@ def get_int_dtype( Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np + >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) >>> dtype = xp.get_int_dtype() - >>> print(dtype) # Output: int64 + >>> print(dtype) + int64 >>> dtype = xp.get_int_dtype("int32") - >>> print(dtype) # Output: int32 + >>> print(dtype) + int32 Now switch to a PyTorch backend: @@ -401,10 +449,12 @@ def get_int_dtype( >>> xp.set_backend(apc_torch) >>> dtype = xp.get_int_dtype() - >>> print(dtype) # Output: torch.int64 + >>> print(dtype) + torch.int64 >>> dtype = xp.get_int_dtype("int32") - >>> print(dtype) # Output: torch.int32 + >>> print(dtype) + torch.int32 """ @@ -437,26 +487,32 @@ def get_complex_dtype( Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np + >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) >>> dtype = xp.get_complex_dtype() - >>> print(dtype) # Output: complex128 + >>> print(dtype) + complex128 >>> dtype = xp.get_complex_dtype("complex64") - >>> print(dtype) # Output: complex64 + >>> print(dtype) + complex64 Now switch to a PyTorch backend: >>> from array_api_compat import torch as apc_torch + >>> >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) >>> dtype = xp.get_complex_dtype() - >>> print(dtype) # Output: torch.complex64 + >>> print(dtype) + torch.complex64 >>> dtype = xp.get_complex_dtype("complex64") - >>> print(dtype) # Output: torch.complex64 + >>> print(dtype) + torch.complex64 """ @@ -489,26 +545,32 @@ def get_bool_dtype( Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np + >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) >>> dtype = xp.get_bool_dtype() - >>> print(dtype) # Output: bool + >>> print(dtype) + bool >>> dtype = xp.get_bool_dtype(dtype="bool") - >>> print(dtype) # Output: bool + >>> print(dtype) + bool Now switch to a PyTorch backend: >>> from array_api_compat import torch as apc_torch + >>> >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) >>> dtype = xp.get_bool_dtype() - >>> print(dtype) # Output: torch.bool + >>> print(dtype) + torch.bool >>> dtype = xp.get_bool_dtype(dtype="bool") - >>> print(dtype) # Output: torch.bool + >>> print(dtype) + torch.bool """ @@ -538,17 +600,22 @@ def __getattr__( Access NumPy's arange function transparently through the proxy: >>> from array_api_compat import numpy as apc_np + >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) >>> array = xp.arange(4) - >>> print(array) # Output: [0 1 2 3] + >>> print(array) + [0 1 2 3] Now switch to a PyTorch backend: + >>> from array_api_compat import torch as apc_torch + >>> >>> xp = ._Proxy("torch") >>> xp.set_backend(apc_torch) >>> array = xp.arange(4) - >>> print(array) # Output: tensor([0, 1, 2, 3]) + >>> print(array) + tensor([0, 1, 2, 3]) Analogously, you can access any attribute or function available in the current backend. @@ -568,14 +635,18 @@ def __dir__(self: _Proxy) -> list[str]: Examples -------- List the attributes (functions, constants, etc.) in the NumPy backend: + >>> from array_api_compat import numpy as apc_np + >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) >>> attrs_numpy = dir(xp) >>> print(attrs_numpy) List the attributes in the PyTorch backend: + >>> from array_api_compat import torch as apc_torch + >>> >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) >>> attrs_torch = dir(xp) @@ -648,15 +719,21 @@ class Config: >>> from deeptrack.backend import config - >>> print(config.get_backend()) # Output: 'numpy' - >>> print(config.get_device()) # Output: 'cpu' + >>> print(config.get_backend()) + 'numpy' + + >>> print(config.get_device()) + 'cpu' Set the backend to PyTorch and device to GPU: >>> config.set_backend_torch() + >>> print(config.get_backend()) + 'torch' + >>> config.set_device("cuda") - >>> print(config.get_backend()) # Output: 'torch' - >>> print(config.get_device()) # Output: 'cuda' + >>> print(config.get_device()) + 'cuda' Use the xp proxy to create arrays/tensors: @@ -664,29 +741,35 @@ class Config: >>> config.set_backend_numpy() >>> array = xp.arange(5) - >>> print(type(array)) # Output: + >>> print(type(array)) + >>> config.set_backend_torch() >>> tensor = xp.arange(5) - >>> print(type(tensor)) # Output: + >>> print(type(tensor)) + Temporarily switch backend using a context manager: >>> config.set_backend("torch") - >>> print(config.get_backend()) # Output: 'torch' + >>> print(config.get_backend()) + 'torch' >>> with config.with_backend("numpy"): - ... print(config.get_backend()) # Output: 'numpy' + ... print(config.get_backend()) + 'numpy' - >>> print(config.get_backend()) # Output: 'torch' + >>> print(config.get_backend()) + 'torch' Use a torch.device object directly: >>> import torch - + >>> >>> config.set_backend_torch() >>> config.set_device(torch.device("cuda:0")) - >>> print(config.get_device()) # Output: device(type='cuda', index=0) + >>> print(config.get_device()) + device(type='cuda', index=0) """ @@ -731,32 +814,37 @@ def set_device( Set device to CPU (works with both NumPy and PyTorch backends): >>> config.set_device("cpu") - >>> print(config.get_device()) # Output: cpu + >>> print(config.get_device()) + cpu Set device to GPU (requires PyTorch backend): >>> config.set_backend_torch() >>> config.set_device("cuda") - >>> print(config.get_device()) # Output: cuda + >>> print(config.get_device()) + cuda Use a specific CUDA device (PyTorch backend): >>> import torch >>> config.set_backend_torch() >>> config.set_device(torch.device("cuda:0")) - >>> print(config.get_device()) # Output: device(type='cuda', index=0) + >>> print(config.get_device()) + device(type='cuda', index=0) Set device to Apple Silicon GPU (PyTorch backend on Macs): >>> config.set_backend_torch() >>> config.set_device("mps") - >>> print(config.get_device()) # Output: mps + >>> print(config.get_device()) + mps Attempting to set a GPU device with NumPy backend (should be avoided): >>> config.set_backend_numpy() >>> config.set_device("cuda") - >>> print(config.get_device()) # Output: cuda + >>> print(config.get_device()) + cuda Computation will still run on CPU, since NumPy does not support GPU. @@ -800,13 +888,15 @@ def set_backend_numpy(self: Config) -> None: Set the backend to NumPy: >>> config.set_backend_numpy() - >>> print(config.get_backend()) # Output: 'numpy' + >>> print(config.get_backend()) + 'numpy' NumPy backend enables use of standard NumPy arrays via the xp proxy: >>> from deeptrack.backend import xp >>> array = xp.arange(5) - >>> print(type(array)) # Output: + >>> print(type(array)) + """ @@ -824,14 +914,16 @@ def set_backend_torch(self: Config) -> None: Set the backend to PyTorch: >>> config.set_backend_torch() - >>> print(config.get_backend()) # Output: 'torch' + >>> print(config.get_backend()) + 'torch' PyTorch backend enables use of PyTorch tensors via the xp proxy: >>> from deeptrack.backend import xp >>> tensor = xp.arange(5) - >>> print(type(tensor)) # Output: + >>> print(type(tensor)) + """ @@ -857,12 +949,14 @@ def set_backend( Set the backend to NumPy: >>> config.set_backend("numpy") - >>> print(config.get_backend()) # Output: 'numpy' + >>> print(config.get_backend()) + 'numpy' Set the backend to PyTorch: >>> config.set_backend("torch") - >>> print(config.get_backend()) # Output: 'torch' + >>> print(config.get_backend()) + 'torch' Switch between backends as needed in your workflow using the xp proxy: @@ -870,11 +964,13 @@ def set_backend( >>> config.set_backend("numpy") >>> array = xp.arange(4) - >>> print(type(array)) # Output: + >>> print(type(array)) + >>> config.set_backend("torch") >>> tensor = xp.arange(4) - >>> print(type(tensor)) # Output: + >>> print(type(tensor)) + """ @@ -937,12 +1033,15 @@ def with_backend( Temporarily switch to the NumPy backend for a block of code: >>> config.set_backend("torch") - >>> print(config.get_backend()) # Output: 'torch' + >>> print(config.get_backend()) + 'torch' >>> with config.with_backend("numpy"): - ... print(config.get_backend()) # Output: 'numpy' + ... print(config.get_backend()) + 'numpy' - >>> print(config.get_backend()) # Output: 'torch' + >>> print(config.get_backend()) + 'torch' Temporarily switch to the PyTorch backend inside a function: @@ -955,9 +1054,11 @@ def with_backend( ... return xp.arange(3) >>> tensor = do_torch_operation() - >>> print(type(tensor)) # Output: + >>> print(type(tensor)) + - >>> print(config.get_backend()) # Output: 'numpy' + >>> print(config.get_backend()) + 'numpy' """ From b681268e416ca6ae0a2a539d0344bc22a033e2a9 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 14:03:52 +0200 Subject: [PATCH 03/54] Update core.py --- deeptrack/backend/core.py | 229 ++++++++++++++++++++------------------ 1 file changed, 121 insertions(+), 108 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 9b71340af..6df7a0919 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -39,7 +39,7 @@ - `DeepTrackDataDict`: Dictionary to store multiple data with validation. A data container to store multiple data objects (`DeepTrackDataObject`) - indexed by unique access IDs (consisting of tuples of integers), enabling + indexed by unique access _IDs (consisting of tuples of integers), enabling nested data storage. Computation Nodes: @@ -224,51 +224,52 @@ def validate(self: DeepTrackDataObject) -> None: class DeepTrackDataDict: - """Stores multiple data objects indexed by a tuple of integers (ID). + """Stores multiple data objects indexed by tuples of integers (_ID). `DeepTrackDataDict` can store multiple `DeepTrackDataObject` instances, - each associated with a unique tuple of integers (its ID). This is + each associated with a unique tuple of integers (its _ID). This is particularly useful to handle sequences of data or nested structures. - The default ID is an empty tuple, `()`. Once the first entry is created, - all IDs must match the established key length: - - - If an ID longer than the set length is requested, it is trimmed. - - If an ID shorter than the set length is requested, a dictionary slice + The default _ID is an empty tuple, `()`. Once the first entry is created, + all _IDs must match the established key length: + - If an _ID longer than the set length is requested, it is trimmed. + - If an _ID shorter than the set length is requested, a dictionary slice containing all matching entries is returned. Attributes ---------- - keylength : int or None - The length of the IDs currently stored. Set when the first entry is - created. If `None`, no entries have been created yet, and any ID length - is valid. - dict : Dict[Tuple[int, ...], DeepTrackDataObject] - A dictionary mapping tuples of integers (IDs) to `DeepTrackDataObject` - instances. + keylength: int or None + The length of the _IDs currently stored. Set when the first entry is + created. If `None`, no entries have been created yet, and any _ID + length is valid. + dict: dict[tuple[int, ...], DeepTrackDataObject] + A dictionary mapping tuples of integers (_IDs) to + `DeepTrackDataObject` instances. Methods ------- - invalidate() -> None - Marks all stored data objects as invalid. - validate() -> None - Marks all stored data objects as valid. - valid_index(_ID : Tuple[int, ...]) -> bool - Checks if the given ID is valid for the current configuration. - create_index(_ID : Tuple[int, ...] = ()) -> None - Creates an entry for the given ID if it does not exist. - __getitem__(_ID : Tuple[int, ...]) -> DeepTrackDataObject or Dict[Tuple[int, ...], DeepTrackDataObject] - Retrieves data associated with the ID. Can return a + `invalidate() -> None` + Mark all stored data objects as invalid. + `validate() -> None` + Mark all stored data objects as valid. + `valid_index(_ID : tuple[int, ...]) -> bool` + Check if the given _ID is valid for the current configuration. + `create_index(_ID : tuple[int, ...] = ()) -> None` + Create an entry for the given _ID if it does not exist. + `__getitem__(_ID : tuple[int, ...]) -> DeepTrackDataObject or dict[tuple[int, ...], DeepTrackDataObject]` + Retrieve data associated with the _ID. Can return a `DeepTrackDataObject` or a dict of matching entries if `_ID` is shorter than `keylength`. - __contains__(_ID : Tuple[int, ...]) -> bool - Checks if the given ID exists in the dictionary. + `__contains__(_ID : tuple[int, ...]) -> bool` + Check whether the given _ID exists in the dictionary. Example ------- + >>> import deeptrack as dt + Create a structure to store multiple, indexed instances of data: - >>> data_dict = DeepTrackDataDict() + >>> data_dict = dt.DeepTrackDataDict() Create the entries: @@ -277,14 +278,14 @@ class DeepTrackDataDict: >>> data_dict.create_index((1, 0)) >>> data_dict.create_index((1, 1)) - Store the values associated with each ID: + Store the values associated with each _ID: >>> data_dict[(0, 0)].store("Data at (0, 0)") >>> data_dict[(0, 1)].store("Data at (0, 1)") >>> data_dict[(1, 0)].store("Data at (1, 0)") >>> data_dict[(1, 1)].store("Data at (1, 1)") - Retrieve values based on their IDs: + Retrieve values based on their _IDs: >>> print(data_dict[(0, 0)].current_value()) Data at (0, 0) @@ -292,24 +293,23 @@ class DeepTrackDataDict: >>> print(data_dict[(1, 1)].current_value()) Data at (1, 1) - If requesting a shorter ID, it returns all matching nested entries: + If requesting a shorter _ID, it returns all matching nested entries: >>> print(data_dict[(0,)]) { - (0, 0): , - (0, 1): , + (0, 0): , + (0, 1): , } """ - # Attributes. - keylength: Optional[int] - dict: Dict[Tuple[int, ...], DeepTrackDataObject] + keylength: int + dict: dict[tuple[int, ...], DeepTrackDataObject] - def __init__(self): + def __init__(self: DeepTrackDataDict): """Initialize the data dictionary. - Initializes `keylength` to `None` and `dict` to an empty dictionary, + It initializes `keylength` to `None` and `dict` to an empty dictionary, indicating no data objects are currently stored. """ @@ -317,44 +317,50 @@ def __init__(self): self.keylength = None self.dict = {} - def invalidate(self) -> None: + def invalidate(self: DeepTrackDataDict) -> None: """Mark all stored data objects as invalid. - Calls `invalidate()` on every `DeepTrackDataObject` in the dictionary. - + It calls `invalidate()` on every `DeepTrackDataObject` in the + dictionary. + """ for dataobject in self.dict.values(): dataobject.invalidate() - def validate(self) -> None: + def validate(self: DeepTrackDataDict) -> None: """Mark all stored data objects as valid. - This method calls `validate()` on every `DeepTrackDataObject` in the - dictionary. - + It calls `validate()` on every `DeepTrackDataObject` in the dictionary. + """ for dataobject in self.dict.values(): dataobject.validate() - def valid_index(self, _ID: Tuple[int, ...]) -> bool: - """Check if a given ID is valid for this data dictionary. + def valid_index( + self: DeepTrackDataDict, + _ID: tuple[int, ...], + ) -> bool: + """Check if a given _ID is valid for this data dictionary. If `keylength` is `None`, any tuple `_ID` is considered valid since no - entries have been created yet. If `_ID` already exists in `dict`, it is - automatically valid. Otherwise, `_ID` must have the same length as - `keylength` to be considered valid. + entries have been created yet. + + If `_ID` already exists in `dict`, it is automatically valid. + + Otherwise, `_ID` must have the same length as `keylength` to be + considered valid. Parameters ---------- - _ID : Tuple[int, ...] + _ID: tuple[int, ...] The index to check, consisting of a tuple of integers. Returns ------- bool - `True` if the ID is valid given the current configuration, `False` + `True` if the _ID is valid given the current configuration, `False` otherwise. Raises @@ -364,7 +370,7 @@ def valid_index(self, _ID: Tuple[int, ...]) -> bool: """ - # Ensure `_ID` is a tuple of integers. + # Ensure _ID is a tuple of integers. assert isinstance(_ID, tuple), ( f"Data index {_ID} is not a tuple. Got: {type(_ID).__name__}." ) @@ -381,23 +387,28 @@ def valid_index(self, _ID: Tuple[int, ...]) -> bool: if _ID in self.dict: return True - # Otherwise, the ID length must match the established keylength. + # Otherwise, the _ID length must match the established keylength + # for _ID to be valid. return len(_ID) == self.keylength - def create_index(self, _ID: Tuple[int, ...] = ()) -> None: - """Create a new data entry for the given ID if not already existing. + def create_index( + self: DeepTrackDataDict, + _ID: tuple[int, ...] = (), + ) -> None: + """Create a new data entry for the given _ID if not already existing. - Each newly created index is associated with a new - `DeepTrackDataObject`. If `_ID` is already in `dict`, no new entry is - created. + Each newly created index is associated with a new + `DeepTrackDataObject`. + + If `_ID` is already in `dict`, no new entry is created. If `keylength` is `None`, it is set to the length of `_ID`. Once - established, all subsequently created IDs must have this same length. + established, all subsequently created _IDs must have this same length. Parameters ---------- - _ID : Tuple[int, ...], optional - A tuple of integers representing the ID for the data entry. + _ID: tuple[int, ...], optional + A tuple of integers representing the _ID for the data entry. Default is `()`, which represents a root-level data entry with no nesting. @@ -409,8 +420,8 @@ def create_index(self, _ID: Tuple[int, ...] = ()) -> None: """ - # Check if the given `_ID` is valid. - # (Also: Ensure `_ID` is a tuple of integers.) + # Check if the given _ID is valid. + # (Also: Ensure _ID is a tuple of integers.) assert self.valid_index(_ID), ( f"{_ID} is not a valid index for current dictionary configuration." ) @@ -419,36 +430,34 @@ def create_index(self, _ID: Tuple[int, ...] = ()) -> None: if _ID in self.dict: return - # Create a new DeepTrackDataObject for this ID. + # Create a new DeepTrackDataObject for this _ID. self.dict[_ID] = DeepTrackDataObject() - # If `keylength` is not set, initialize it with current ID's length. + # If `keylength` is not set, initialize it with current _IDs length. if self.keylength is None: self.keylength = len(_ID) def __getitem__( - self, - _ID: Tuple[int, ...], - ) -> Union[ - DeepTrackDataObject, - Dict[Tuple[int, ...], DeepTrackDataObject] - ]: - """Retrieve data associated with a given ID. + self: DeepTrackDataDict, + _ID: tuple[int, ...], + ) -> DeepTrackDataObject | dict[tuple[int, ...], DeepTrackDataObject]: + """Retrieve data associated with a given _ID. Parameters ---------- - _ID : Tuple[int, ...] - The ID for the requested data. + _ID: Tuple[int, ...] + The _ID for the requested data. Returns ------- - DeepTrackDataObject or Dict[Tuple[int, ...], DeepTrackDataObject] - If `_ID` matches `keylength`, returns the corresponding + DeepTrackDataObject or Dict[tuple[int, ...], DeepTrackDataObject] + If `_ID` matches `keylength`, it returns the corresponding `DeepTrackDataObject`. If `_ID` is longer than `keylength`, the request is trimmed to - match `keylength`. - If `_ID` is shorter than `keylength`, returns a dict of all entries - whose IDs match the given `_ID` prefix. + match `keylength` and it returns the corresponding + `DeepTrackDataObject`. + If `_ID` is shorter than `keylength`, it returns a dict of all + entries whose _IDs match the given `_ID` prefix. Raises ------ @@ -456,7 +465,7 @@ def __getitem__( If `_ID` is not a tuple of integers. KeyError If the dictionary is empty (`keylength` is `None`). - + """ # Ensure `_ID` is a tuple of integers. @@ -471,29 +480,33 @@ def __getitem__( if self.keylength is None: raise KeyError("Attempting to index an empty dict.") - # If ID matches keylength, returns corresponding DeepTrackDataObject. + # If _ID matches keylength, return corresponding DeepTrackDataObject. if len(_ID) == self.keylength: return self.dict[_ID] - # If ID longer than keylength, trim the requested ID. + # If _ID longer than keylength, trim the requested _ID + # and return corresponding DeepTrackDataObject. if len(_ID) > self.keylength: return self[_ID[: self.keylength]] - # If ID longer than keylength, return a slice of all matching items. + # If _ID shorter than keylength, return a slice of all matching items. return {k: v for k, v in self.dict.items() if k[: len(_ID)] == _ID} - def __contains__(self, _ID: Tuple[int, ...]) -> bool: - """Check if a given ID exists in the dictionary. + def __contains__( + self: DeepTrackDataDict, + _ID: tuple[int, ...], + ) -> bool: + """Check if a given _ID exists in the dictionary. Parameters ---------- - _ID : Tuple[int, ...] - The ID to check. + _ID : tuple[int, ...] + The _ID to check. Returns ------- bool - `True` if the ID exists, `False` otherwise. + `True` if the _ID exists, `False` otherwise. Raises ------ @@ -502,7 +515,7 @@ def __contains__(self, _ID: Tuple[int, ...]) -> bool: """ - # Ensure `_ID` is a tuple of integers. + # Ensure _ID is a tuple of integers. assert isinstance(_ID, tuple), ( f"Data index {_ID} is not a tuple. Got: {type(_ID).__name__}." ) @@ -533,7 +546,7 @@ class DeepTrackNode: _action : Callable The function or lambda-function to compute the node value. _accepts_ID : bool - Whether `action` accepts an input ID. + Whether `action` accepts an input _ID. _all_children : Set[DeepTrackNode] All nodes in the subtree rooted at the node, including the node itself. _citations : List[str] @@ -597,7 +610,7 @@ class DeepTrackNode: >>> parent.add_child(child) - Store values in the parent node for specific IDs: + Store values in the parent node for specific _IDs: >>> parent.store(15, _ID=(0,)) >>> parent.store(20, _ID=(1,)) @@ -609,7 +622,7 @@ class DeepTrackNode: >>> print(child_value_0, child_value_1) 30 40 - Invalidate the parent data for a specific ID: + Invalidate the parent data for a specific _ID: >>> parent.invalidate((0,)) >>> print(parent.is_valid((0,))) @@ -789,12 +802,12 @@ def store(self, data: Any, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': return self def is_valid(self, _ID: Tuple[int, ...] = ()) -> bool: - """Check if data for the given ID is valid. + """Check if data for the given _ID is valid. Parameters ---------- _ID : Tuple[int, ...], optional - The ID to check validity for. + The _ID to check validity for. Returns ------- @@ -809,12 +822,12 @@ def is_valid(self, _ID: Tuple[int, ...] = ()) -> bool: return False def valid_index(self, _ID: Tuple[int, ...]) -> bool: - """Check if ID is a valid index for this node’s data. + """Check if _ID is a valid index for this node’s data. Parameters ---------- _ID : Tuple[int, ...] - The ID to validate. + The _ID to validate. Returns ------- @@ -831,7 +844,7 @@ def invalidate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': Parameters ---------- _ID : Tuple[int, ...], optional - The ID to invalidate. Default is empty tuple, indicating + The _ID to invalidate. Default is empty tuple, indicating potentially the full dataset. Returns @@ -841,7 +854,7 @@ def invalidate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': Note ---- - At the moment, the code to invalidate specific IDs is not implemented, + At the moment, the code to invalidate specific _IDs is not implemented, so the _ID parameter is not effectively used. """ @@ -859,7 +872,7 @@ def validate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': Parameters ---------- _ID : Tuple[int, ...], optional - The ID to validate. Default is empty tuple. + The _ID to validate. Default is empty tuple. Returns ------- @@ -897,7 +910,7 @@ def update(self) -> 'DeepTrackNode': return self def set_value(self, value, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': - """Set a value for this node’s data at ID. + """Set a value for this node’s data at _ID. If the value is different from the currently stored one (or if it is invalid), it will invalidate the old data before storing the new one. @@ -907,7 +920,7 @@ def set_value(self, value, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': value : Any The value to store. _ID : Tuple[int, ...], optional - The ID at which to store the value. + The _ID at which to store the value. Returns ------- @@ -928,12 +941,12 @@ def set_value(self, value, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': return self def previous(self, _ID: Tuple[int, ...] = ()) -> Any: - """Retrieve the previously stored value at ID without recomputing. + """Retrieve the previously stored value at _ID without recomputing. Parameters ---------- _ID : Tuple[int, ...], optional - The ID for which to retrieve the previous value. + The _ID for which to retrieve the previous value. Returns ------- @@ -1074,7 +1087,7 @@ def get_citations(self) -> Set[str]: return citations def __call__(self, _ID: Tuple[int, ...] = ()) -> Any: - """Evaluate this node at ID. + """Evaluate this node at _ID. If the data at `_ID` is valid, it returns the stored value. Otherwise, it calls `action` to compute a new value, stores it, and returns it. @@ -1082,7 +1095,7 @@ def __call__(self, _ID: Tuple[int, ...] = ()) -> Any: Parameters ---------- _ID : Tuple[int, ...], optional - The ID at which to evaluate the node’s action. + The _ID at which to evaluate the node’s action. Returns ------- @@ -1109,12 +1122,12 @@ def __call__(self, _ID: Tuple[int, ...] = ()) -> Any: return self.current_value(_ID) def current_value(self, _ID: Tuple[int, ...] = ()) -> Any: - """Retrieve the currently stored value at ID. + """Retrieve the currently stored value at _ID. Parameters ---------- _ID : Tuple[int, ...], optional - The ID at which to retrieve the current value. + The _ID at which to retrieve the current value. Returns ------- From a737f833f1caa153bf3b0c730aea48503607aed5 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 14:03:57 +0200 Subject: [PATCH 04/54] Update test_core.py --- deeptrack/tests/backend/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptrack/tests/backend/test_core.py b/deeptrack/tests/backend/test_core.py index c8eb6d432..d357913a9 100644 --- a/deeptrack/tests/backend/test_core.py +++ b/deeptrack/tests/backend/test_core.py @@ -48,7 +48,7 @@ def test_DeepTrackDataDict(self): # Test initial state. self.assertEqual(dataset.keylength, None) - self.assertFalse(dataset.dict) + self.assertFalse(dataset.dict) # Empty dict, {} # Create indices and store data. dataset.create_index((0,)) From 64fbdda695a9034bfc02fe33d919e61ae1ae10e2 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 14:10:42 +0200 Subject: [PATCH 05/54] Update _config.py --- deeptrack/backend/_config.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/deeptrack/backend/_config.py b/deeptrack/backend/_config.py index f0df8c3d0..08169cd52 100644 --- a/deeptrack/backend/_config.py +++ b/deeptrack/backend/_config.py @@ -71,7 +71,6 @@ >>> print(config.get_backend()) 'numpy' - >>> print(config.get_device()) 'cpu' @@ -229,7 +228,6 @@ class _Proxy(types.ModuleType): >>> array = xp.arange(5) >>> print(array) Output: [0 1 2 3 4] - >>> print(type(array)) @@ -241,13 +239,10 @@ class _Proxy(types.ModuleType): >>> print(xp.get_float_dtype()) float64 - >>> print(xp.get_int_dtype()) int64 - >>> print(xp.get_complex_dtype()) complex128 - >>> print(xp.get_bool_dtype()) bool @@ -263,7 +258,6 @@ class _Proxy(types.ModuleType): >>> array = xp.arange(5) >>> print(array) tensor([0, 1, 2, 3, 4]) - >>> print(type(array)) @@ -271,13 +265,10 @@ class _Proxy(types.ModuleType): >>> print(xp.get_float_dtype()) torch.float32 - >>> print(xp.get_int_dtype()) torch.int64 - >>> print(xp.get_complex_dtype()) torch.complex64 - >>> print(xp.get_bool_dtype()) torch.bool @@ -721,7 +712,6 @@ class Config: >>> print(config.get_backend()) 'numpy' - >>> print(config.get_device()) 'cpu' From 99d180d9403064c247e9613deace501857228b56 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 14:16:33 +0200 Subject: [PATCH 06/54] Update core.py --- deeptrack/backend/core.py | 64 ++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 8 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 6df7a0919..083df24bf 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -136,24 +136,24 @@ class DeepTrackDataObject: Store a value in this container: >>> data_obj.store(42) - >>> print(data_obj.current_value()) + >>> data_obj.current_value() 42 Check if the stored data is valid: - >>> print(data_obj.is_valid()) + >>> data_obj.is_valid() True Invalidate the stored data: >>> data_obj.invalidate() - >>> print(data_obj.is_valid()) + >>> data_obj.is_valid() False Validate the data again to restore its valid status: >>> data_obj.validate() - >>> print(data_obj.is_valid()) + >>> data_obj.is_valid() True """ @@ -287,20 +287,68 @@ class DeepTrackDataDict: Retrieve values based on their _IDs: - >>> print(data_dict[(0, 0)].current_value()) + >>> data_dict[(0, 0)].current_value() Data at (0, 0) - >>> print(data_dict[(1, 1)].current_value()) + >>> data_dict[(1, 1)].current_value() Data at (1, 1) If requesting a shorter _ID, it returns all matching nested entries: - >>> print(data_dict[(0,)]) + >>> data_dict[(0,)] { (0, 0): , (0, 1): , } - + + Validate and invalidate all entries at once: + + >>> data_dict.invalidate() + >>> data_dict[(0, 0)].is_valid() + False + >>> data_dict[(1, 1)].is_valid() + False + + >>> data_dict.validate() + >>> data_dict[(0, 0)].is_valid() + True + >>> data_dict[(1, 1)].is_valid() + True + + Invalidate and validate a single entry: + + >>> data_dict[(0, 1)].invalidate() + >>> data_dict[(0, 1)].is_valid() + False + >>> data_dict[(0, 1)].validate() + >>> data_dict[(0, 1)].is_valid() + True + + Check if a given _ID exists: + + >>> (1, 0) in data_dict + True + >>> (2, 2) in data_dict + False + + Iterate over all entries: + + >>> for key, value in data_dict.dict.items(): + ... print(key, value.current_value()) + (0, 0) Data at (0, 0) + (0, 1) Data at (0, 1) + (1, 0) Data at (1, 0) + (1, 1) Data at (1, 1) + + Check if an _ID is valid according to current keylength: + + >>> data_dict.valid_index((0, 1)) + True + >>> data_dict.valid_index((0,)) # Shorter than keylength after creation + False + >>> data_dict.valid_index((2, 2)) # Valid length, even if not created yet + True + """ keylength: int From 2f4d2510add38b187c6a8bbd7a803199d93e9bd1 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 14:50:13 +0200 Subject: [PATCH 07/54] Update _config.py --- deeptrack/backend/_config.py | 280 ++++++++++++++++++----------------- 1 file changed, 147 insertions(+), 133 deletions(-) diff --git a/deeptrack/backend/_config.py b/deeptrack/backend/_config.py index 08169cd52..4a578f149 100644 --- a/deeptrack/backend/_config.py +++ b/deeptrack/backend/_config.py @@ -69,63 +69,64 @@ Check the default backend and device: ->>> print(config.get_backend()) +>>> config.get_backend() 'numpy' ->>> print(config.get_device()) + +>>> config.get_device() 'cpu' Use the xp proxy to create a NumPy array: >>> array = xp.arange(5) ->>> print(type(array)) - +>>> type(array) +numpy.ndarray Switch to the PyTorch backend and use GPU: >>> config.set_backend_torch() ->>> print(config.get_backend()) +>>> config.get_backend() 'torch' >>> config.set_device("cuda") ->>> print(config.get_device()) +>>> config.get_device() 'cuda' Create a tensor using the xp proxy: >>> tensor = xp.arange(3) ->>> print(type(tensor)) - +>>> type(tensor) +torch.Tensor Temporarily switch backends within a context manager: ->>> print(config.get_backend()) +>>> config.get_backend() 'torch' >>> with config.with_backend("numpy"): ... print(config.get_backend()) -'numpy' +numpy ->>> print(config.get_backend()) +>>> config.get_backend() 'torch' Use PyTorch-specific device objects if desired: >>> import torch - +>>> >>> config.set_device(torch.device("cuda:0")) ->>> print(config.get_device()) +>>> config.get_device() device(type='cuda', index=0) Check PyTorch availability: >>> from deeptrack.backend import TORCH_AVAILABLE - +>>> >>> print(TORCH_AVAILABLE) Check OpenCV availability: >>> from deeptrack.backend import OPENCV_AVAILABLE - +>>> >>> print(OPENCV_AVAILABLE) """ @@ -216,6 +217,8 @@ class _Proxy(types.ModuleType): Examples -------- + >>> from deeptrack.backend._config import _Proxy + Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np @@ -226,10 +229,11 @@ class _Proxy(types.ModuleType): Use the proxy to create an array (calls NumPy under the hood): >>> array = xp.arange(5) - >>> print(array) - Output: [0 1 2 3 4] - >>> print(type(array)) - + >>> array, type(array) + array([0, 1, 2, 3, 4]) + + >>> type(array) + numpy.ndarray You can use any function or attribute provided by the backend: @@ -237,14 +241,18 @@ class _Proxy(types.ModuleType): Query dtypes in a backend-agnostic way: - >>> print(xp.get_float_dtype()) - float64 - >>> print(xp.get_int_dtype()) - int64 - >>> print(xp.get_complex_dtype()) - complex128 - >>> print(xp.get_bool_dtype()) - bool + >>> xp.get_float_dtype() + dtype('float64') + + >>> xp.get_int_dtype() + dtype('int64') + + >>> xp.get_complex_dtype() + dtype('complex128') + + + >>> xp.get_bool_dtype() + dtype('bool') Switch to the PyTorch backend: @@ -255,29 +263,33 @@ class _Proxy(types.ModuleType): Now the proxy uses PyTorch: - >>> array = xp.arange(5) - >>> print(array) + >>> tensor = xp.arange(5) + >>> tensor tensor([0, 1, 2, 3, 4]) - >>> print(type(array)) - + + >>> type(tensor) + torch.Tensor The dtype helpers return PyTorch-specific types: - >>> print(xp.get_float_dtype()) + >>> xp.get_float_dtype() torch.float32 - >>> print(xp.get_int_dtype()) + + >>> xp.get_int_dtype() torch.int64 - >>> print(xp.get_complex_dtype()) + + >>> xp.get_complex_dtype() torch.complex64 - >>> print(xp.get_bool_dtype()) + + >>> xp.get_bool_dtype() torch.bool You can switch backends as often as needed.: >>> xp.set_backend(apc_np) >>> array = xp.arange(3) - >>> print(type(array)) - + >>> type(array) + numpy.ndarray """ @@ -313,6 +325,8 @@ def set_backend( Examples -------- + >>> from deeptrack.backend._config import _Proxy + Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np @@ -320,8 +334,8 @@ def set_backend( >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) >>> array = xp.arange(5) - >>> print(type(array)) - + >>> type(array) + numpy.ndarray Now switch to a PyTorch backend: @@ -329,9 +343,9 @@ def set_backend( >>> >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) - >>> array = xp.arange(5) - >>> print(type(array)) - + >>> tensor = xp.arange(5) + >>> type(tensor) + torch.Tensor """ @@ -360,6 +374,8 @@ def get_float_dtype( Examples -------- + >>> from deeptrack.backend._config import _Proxy + Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np @@ -367,13 +383,11 @@ def get_float_dtype( >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) - >>> dtype = xp.get_float_dtype() - >>> print(dtype) - float64 + >>> xp.get_float_dtype() + dtype('float64') - >>> dtype = xp.get_float_dtype("float32") - >>> print(dtype) - float32 + >>> xp.get_float_dtype("float32") + dtype('float32') Now switch to a PyTorch backend: @@ -382,12 +396,10 @@ def get_float_dtype( >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) - >>> dtype = xp.get_float_dtype() - >>> print(dtype) + >>> xp.get_float_dtype() torch.float32 - >>> dtype = xp.get_float_dtype("float32") - >>> print(dtype) + >>> xp.get_float_dtype("float32") torch.float32 """ @@ -418,6 +430,8 @@ def get_int_dtype( Examples -------- + >>> from deeptrack.backend._config import _Proxy + Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np @@ -425,26 +439,23 @@ def get_int_dtype( >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) - >>> dtype = xp.get_int_dtype() - >>> print(dtype) - int64 + >>> xp.get_int_dtype() + dtype('int64') - >>> dtype = xp.get_int_dtype("int32") - >>> print(dtype) - int32 + >>> xp.get_int_dtype("int32") + dtype('int32') Now switch to a PyTorch backend: >>> from array_api_compat import torch as apc_torch + >>> >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) - >>> dtype = xp.get_int_dtype() - >>> print(dtype) + >>> xp.get_int_dtype() torch.int64 - >>> dtype = xp.get_int_dtype("int32") - >>> print(dtype) + >>> xp.get_int_dtype("int32") torch.int32 """ @@ -475,6 +486,8 @@ def get_complex_dtype( Examples -------- + >>> from deeptrack.backend._config import _Proxy + Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np @@ -482,13 +495,11 @@ def get_complex_dtype( >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) - >>> dtype = xp.get_complex_dtype() - >>> print(dtype) - complex128 + >>> xp.get_complex_dtype() + dtype('complex128') - >>> dtype = xp.get_complex_dtype("complex64") - >>> print(dtype) - complex64 + >>> xp.get_complex_dtype("complex64") + dtype('complex64') Now switch to a PyTorch backend: @@ -497,12 +508,10 @@ def get_complex_dtype( >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) - >>> dtype = xp.get_complex_dtype() - >>> print(dtype) + >>> xp.get_complex_dtype() torch.complex64 - >>> dtype = xp.get_complex_dtype("complex64") - >>> print(dtype) + >>> xp.get_complex_dtype("complex64") torch.complex64 """ @@ -533,6 +542,8 @@ def get_bool_dtype( Examples -------- + >>> from deeptrack.backend._config import _Proxy + Create a proxy instance and set the backend to NumPy: >>> from array_api_compat import numpy as apc_np @@ -540,13 +551,11 @@ def get_bool_dtype( >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) - >>> dtype = xp.get_bool_dtype() - >>> print(dtype) - bool + >>> xp.get_bool_dtype() + dtype('bool') - >>> dtype = xp.get_bool_dtype(dtype="bool") - >>> print(dtype) - bool + >>> xp.get_bool_dtype(dtype="bool") + dtype('bool') Now switch to a PyTorch backend: @@ -555,12 +564,10 @@ def get_bool_dtype( >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) - >>> dtype = xp.get_bool_dtype() - >>> print(dtype) + >>> xp.get_bool_dtype() torch.bool - >>> dtype = xp.get_bool_dtype(dtype="bool") - >>> print(dtype) + >>> xp.get_bool_dtype(dtype="bool") torch.bool """ @@ -588,24 +595,24 @@ def __getattr__( Examples -------- + >>> from deeptrack.backend._config import _Proxy + Access NumPy's arange function transparently through the proxy: >>> from array_api_compat import numpy as apc_np >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) - >>> array = xp.arange(4) - >>> print(array) - [0 1 2 3] + >>> xp.arange(4) + array([0, 1, 2, 3]) Now switch to a PyTorch backend: >>> from array_api_compat import torch as apc_torch >>> - >>> xp = ._Proxy("torch") + >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) - >>> array = xp.arange(4) - >>> print(array) + >>> xp.arange(4) tensor([0, 1, 2, 3]) Analogously, you can access any attribute or function available in the @@ -625,14 +632,17 @@ def __dir__(self: _Proxy) -> list[str]: Examples -------- + >>> from deeptrack.backend._config import _Proxy + List the attributes (functions, constants, etc.) in the NumPy backend: >>> from array_api_compat import numpy as apc_np >>> >>> xp = _Proxy("numpy") >>> xp.set_backend(apc_np) - >>> attrs_numpy = dir(xp) - >>> print(attrs_numpy) + >>> dir(xp) + ['ALLOW_THREADS', + ...] List the attributes in the PyTorch backend: @@ -640,8 +650,9 @@ def __dir__(self: _Proxy) -> list[str]: >>> >>> xp = _Proxy("torch") >>> xp.set_backend(apc_torch) - >>> attrs_torch = dir(xp) - >>> print(attrs_torch) + >>> dir(xp) + ['AVG', + ...] """ @@ -710,19 +721,20 @@ class Config: >>> from deeptrack.backend import config - >>> print(config.get_backend()) + >>> config.get_backend() 'numpy' - >>> print(config.get_device()) + + >>> config.get_device() 'cpu' Set the backend to PyTorch and device to GPU: >>> config.set_backend_torch() - >>> print(config.get_backend()) + >>> config.get_backend() 'torch' >>> config.set_device("cuda") - >>> print(config.get_device()) + >>> config.get_device() 'cuda' Use the xp proxy to create arrays/tensors: @@ -731,25 +743,25 @@ class Config: >>> config.set_backend_numpy() >>> array = xp.arange(5) - >>> print(type(array)) - + >>> type(array) + numpy.ndarray >>> config.set_backend_torch() >>> tensor = xp.arange(5) - >>> print(type(tensor)) - + >>> type(tensor) + torch.Tensor Temporarily switch backend using a context manager: >>> config.set_backend("torch") - >>> print(config.get_backend()) + >>> config.get_backend() 'torch' >>> with config.with_backend("numpy"): ... print(config.get_backend()) - 'numpy' + numpy - >>> print(config.get_backend()) + >>> config.get_backend() 'torch' Use a torch.device object directly: @@ -758,7 +770,7 @@ class Config: >>> >>> config.set_backend_torch() >>> config.set_device(torch.device("cuda:0")) - >>> print(config.get_device()) + >>> config.get_device() device(type='cuda', index=0) """ @@ -804,37 +816,38 @@ def set_device( Set device to CPU (works with both NumPy and PyTorch backends): >>> config.set_device("cpu") - >>> print(config.get_device()) - cpu + >>> config.get_device() + 'cpu' Set device to GPU (requires PyTorch backend): >>> config.set_backend_torch() >>> config.set_device("cuda") - >>> print(config.get_device()) - cuda + >>> config.get_device() + 'cuda' Use a specific CUDA device (PyTorch backend): >>> import torch + >>> >>> config.set_backend_torch() >>> config.set_device(torch.device("cuda:0")) - >>> print(config.get_device()) + >>> config.get_device() device(type='cuda', index=0) Set device to Apple Silicon GPU (PyTorch backend on Macs): >>> config.set_backend_torch() >>> config.set_device("mps") - >>> print(config.get_device()) - mps + >>> config.get_device() + 'mps' Attempting to set a GPU device with NumPy backend (should be avoided): >>> config.set_backend_numpy() >>> config.set_device("cuda") - >>> print(config.get_device()) - cuda + >>> config.get_device() + 'cuda' Computation will still run on CPU, since NumPy does not support GPU. @@ -878,15 +891,16 @@ def set_backend_numpy(self: Config) -> None: Set the backend to NumPy: >>> config.set_backend_numpy() - >>> print(config.get_backend()) + >>> config.get_backend() 'numpy' NumPy backend enables use of standard NumPy arrays via the xp proxy: >>> from deeptrack.backend import xp + >>> >>> array = xp.arange(5) - >>> print(type(array)) - + >>> type(array) + numpy.ndarray """ @@ -904,16 +918,16 @@ def set_backend_torch(self: Config) -> None: Set the backend to PyTorch: >>> config.set_backend_torch() - >>> print(config.get_backend()) + >>> config.get_backend() 'torch' PyTorch backend enables use of PyTorch tensors via the xp proxy: >>> from deeptrack.backend import xp - + >>> >>> tensor = xp.arange(5) - >>> print(type(tensor)) - + >>> type(tensor) + torch.Tensor """ @@ -939,13 +953,13 @@ def set_backend( Set the backend to NumPy: >>> config.set_backend("numpy") - >>> print(config.get_backend()) + >>> config.get_backend() 'numpy' Set the backend to PyTorch: >>> config.set_backend("torch") - >>> print(config.get_backend()) + >>> config.get_backend() 'torch' Switch between backends as needed in your workflow using the xp proxy: @@ -954,13 +968,13 @@ def set_backend( >>> config.set_backend("numpy") >>> array = xp.arange(4) - >>> print(type(array)) - + >>> type(array) + numpy.ndarray >>> config.set_backend("torch") >>> tensor = xp.arange(4) - >>> print(type(tensor)) - + >>> type(tensor) + torch.Tensor """ @@ -1023,31 +1037,31 @@ def with_backend( Temporarily switch to the NumPy backend for a block of code: >>> config.set_backend("torch") - >>> print(config.get_backend()) + >>> config.get_backend() 'torch' >>> with config.with_backend("numpy"): ... print(config.get_backend()) - 'numpy' + numpy - >>> print(config.get_backend()) + >>> config.get_backend() 'torch' Temporarily switch to the PyTorch backend inside a function: >>> from deeptrack.backend import xp - >>> config.set_backend("numpy") + >>> config.set_backend("numpy")config.set_backend("numpy") >>> def do_torch_operation(): ... with config.with_backend("torch"): ... return xp.arange(3) >>> tensor = do_torch_operation() - >>> print(type(tensor)) - + >>> type(tensor) + torch.Tensor - >>> print(config.get_backend()) + >>> config.get_backend() 'numpy' """ From 3ae3962146e8471a989b9d68beffd32040f98c4e Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 15:06:41 +0200 Subject: [PATCH 08/54] Update core.py --- deeptrack/backend/core.py | 358 +++++++++++++++++++++++--------------- 1 file changed, 222 insertions(+), 136 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 083df24bf..c58b03052 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -252,15 +252,15 @@ class DeepTrackDataDict: Mark all stored data objects as invalid. `validate() -> None` Mark all stored data objects as valid. - `valid_index(_ID : tuple[int, ...]) -> bool` + `valid_index(_ID: tuple[int, ...]) -> bool` Check if the given _ID is valid for the current configuration. - `create_index(_ID : tuple[int, ...] = ()) -> None` + `create_index(_ID: tuple[int, ...] = ()) -> None` Create an entry for the given _ID if it does not exist. - `__getitem__(_ID : tuple[int, ...]) -> DeepTrackDataObject or dict[tuple[int, ...], DeepTrackDataObject]` + `__getitem__(_ID: tuple[int, ...]) -> DeepTrackDataObject or dict[tuple[int, ...], DeepTrackDataObject]` Retrieve data associated with the _ID. Can return a `DeepTrackDataObject` or a dict of matching entries if `_ID` is shorter than `keylength`. - `__contains__(_ID : tuple[int, ...]) -> bool` + `__contains__(_ID: tuple[int, ...]) -> bool` Check whether the given _ID exists in the dictionary. Example @@ -493,7 +493,7 @@ def __getitem__( Parameters ---------- - _ID: Tuple[int, ...] + _ID: tuple[int, ...] The _ID for the requested data. Returns @@ -548,7 +548,7 @@ def __contains__( Parameters ---------- - _ID : tuple[int, ...] + _ID: tuple[int, ...] The _ID to check. Returns @@ -585,24 +585,24 @@ class DeepTrackNode: Attributes ---------- - data : DeepTrackDataDict + data: DeepTrackDataDict Dictionary-like object for storing data, indexed by tuples of integers. - children : WeakSet[DeepTrackNode] + children: WeakSet[DeepTrackNode] Nodes that depend on this node (its children, grandchildren, etc.). - dependencies : WeakSet[DeepTrackNode] + dependencies: WeakSet[DeepTrackNode] Nodes on which this node depends (its parents, grandparents, etc.). - _action : Callable + _action: Callable The function or lambda-function to compute the node value. - _accepts_ID : bool + _accepts_ID: bool Whether `action` accepts an input _ID. - _all_children : Set[DeepTrackNode] + _all_children: set[DeepTrackNode] All nodes in the subtree rooted at the node, including the node itself. - _citations : List[str] + _citations: list[str] Citations associated with this node. Methods ------- - action : property + action: Property Gets or sets the computation function for the node. add_child(child: DeepTrackNode) -> DeepTrackNode Adds a child node that depends on this node. @@ -610,36 +610,36 @@ class DeepTrackNode: add_dependency(parent: DeepTrackNode) -> DeepTrackNode Adds a dependency, making this node depend on the parent node. It also sets this node as a child of the parent node. - store(data: Any, _ID: Tuple[int, ...] = ()) -> DeepTrackNode + store(data: Any, _ID: tuple[int, ...] = ()) -> DeepTrackNode Stores computed data for the given `_ID`. - is_valid(_ID: Tuple[int, ...] = ()) -> bool + is_valid(_ID: tuple[int, ...] = ()) -> bool Checks if the data for the given `_ID` is valid. - valid_index(_ID: Tuple[int, ...]) -> bool + valid_index(_ID: tuple[int, ...]) -> bool Checks if the given `_ID` is valid for this node. - invalidate(_ID: Tuple[int, ...] = ()) -> DeepTrackNode + invalidate(_ID: tuple[int, ...] = ()) -> DeepTrackNode Invalidates the data for the given `_ID` and all child nodes. - validate(_ID: Tuple[int, ...] = ()) -> DeepTrackNode + validate(_ID: tuple[int, ...] = ()) -> DeepTrackNode Validates the data for the given `_ID`, marking it as up-to-date, but not its children. update() -> DeepTrackNode Resets the data. - set_value(value: Any, _ID: Tuple[int, ...] = ()) -> DeepTrackNode + set_value(value: Any, _ID: tuple[int, ...] = ()) -> DeepTrackNode Sets a value for the given `_ID`. If the new value differs from the current value, the node is invalidated to ensure dependencies are recomputed. - previous(_ID: Tuple[int, ...] = ()) -> Any + previous(_ID: tuple[int, ...] = ()) -> Any Returns the previously stored value for the given `_ID` without recomputing it. - recurse_children(memory: Optional[Set[DeepTrackNode]] = None) -> Set[DeepTrackNode] + recurse_children(memory: set[DeepTrackNode] | None = None) -> set[DeepTrackNode] Returns all child nodes in the dependency tree rooted at this node. - recurse_dependencies(memory: Optional[List[DeepTrackNode]] = None) -> Iterator[DeepTrackNode] + recurse_dependencies(memory: list[DeepTrackNode] | None = None) -> Iterator[DeepTrackNode] Yields all nodes that this node depends on, traversing dependencies. - get_citations() -> Set[str] + get_citations() -> set[str] Returns a set of citations for this node and its dependencies. - __call__(_ID: Tuple[int, ...] = ()) -> Any + __call__(_ID: tuple[int, ...] = ()) -> Any Evaluates the node's computation for the given `_ID`, recomputing if necessary. - current_value(_ID: Tuple[int, ...] = ()) -> Any + current_value(_ID: tuple[int, ...] = ()) -> Any Returns the currently stored value for the given `_ID` without recomputation. __hash__() -> int @@ -689,17 +689,17 @@ class DeepTrackNode: # Attributes. data: DeepTrackDataDict - children: WeakSet['DeepTrackNode'] - dependencies: WeakSet['DeepTrackNode'] + children: WeakSet[DeepTrackNode] + dependencies: WeakSet[DeepTrackNode] _action: Callable[..., Any] _accepts_ID: bool - _all_children: Set['DeepTrackNode'] + _all_children: set[DeepTrackNode] # Citations associated with DeepTrack2. - _citations: List[str] = [CITATION_MIDTVEDT2021QUANTITATIVE] + _citations: list[str] = [CITATION_MIDTVEDT2021QUANTITATIVE] @property - def action(self) -> Callable[..., Any]: + def action(self: DeepTrackNode) -> Callable[..., Any]: """Callable: The function that computes this node’s value. When accessed, returns the current action. This is often a function or @@ -711,12 +711,15 @@ def action(self) -> Callable[..., Any]: return self._action @action.setter - def action(self, value: Callable[..., Any]) -> None: + def action( + self: DeepTrackNode, + value: Callable[..., Any], + ) -> None: """Set the action used to compute this node’s value. Parameters ---------- - value : Callable[..., Any] + value: Callable[..., Any] A function or lambda to be used for computing the node’s value. If the function’s signature includes `_ID`, this node will pass `_ID` when calling `action`. @@ -727,19 +730,19 @@ def action(self, value: Callable[..., Any]) -> None: self._accepts_ID = "_ID" in get_kwarg_names(value) def __init__( - self, - action: Optional[Callable[..., Any]] = None, + self: DeepTrackNode, + action: Callable[..., Any] | None = None, **kwargs: Any, ): """Initialize a new DeepTrackNode. Parameters ---------- - action : Callable or Any, optional + action: Callable or Any, optional Action to compute this node’s value. If not provided, uses a no-op action (lambda: None). - **kwargs : dict + **kwargs: dict Additional arguments for subclasses or extended functionality. """ @@ -768,7 +771,10 @@ def __init__( self._all_children = set() self._all_children.add(self) - def add_child(self, child: 'DeepTrackNode') -> 'DeepTrackNode': + def add_child( + self: DeepTrackNode, + child: DeepTrackNode, + ) -> DeepTrackNode: """Add a child node to the current node. Adding a child also updates `_all_children` for this node and all @@ -777,12 +783,12 @@ def add_child(self, child: 'DeepTrackNode') -> 'DeepTrackNode': Parameters ---------- - child : DeepTrackNode + child: DeepTrackNode The child node that depends on this node. Returns ------- - self : DeepTrackNode + self: DeepTrackNode Returns the current node for chaining. """ @@ -802,18 +808,21 @@ def add_child(self, child: 'DeepTrackNode') -> 'DeepTrackNode': return self - def add_dependency(self, parent: 'DeepTrackNode') -> 'DeepTrackNode': + def add_dependency( + self: DeepTrackNode, + parent: DeepTrackNode, + ) -> DeepTrackNode: """Adds a dependency, making this node depend on a parent node. Parameters ---------- - parent : DeepTrackNode + parent: DeepTrackNode The parent node that this node depends on. If `parent` changes, this node’s data may become invalid. Returns ------- - self : DeepTrackNode + self: DeepTrackNode Returns the current node for chaining. """ @@ -824,37 +833,44 @@ def add_dependency(self, parent: 'DeepTrackNode') -> 'DeepTrackNode': return self - def store(self, data: Any, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': + def store( + self: DeepTrackNode, + data: Any, + _ID: tuple[int, ...] = (), + ) -> DeepTrackNode: """Store computed data in this node. Parameters ---------- - data : Any + data: Any The data to be stored. - _ID : Tuple[int, ...], optional + _ID: tuple[int, ...], optional The index for this data. Default is the empty tuple (), indicating a root-level entry. Returns ------- - self : DeepTrackNode + self: DeepTrackNode Returns the current node for chaining. """ # Create the index if necessary, then store data in it. self.data.create_index(_ID) - + self.data[_ID].store(data) return self - def is_valid(self, _ID: Tuple[int, ...] = ()) -> bool: + def is_valid( + self: DeepTrackNode, + _ID: tuple[int, ...] = (), + ) -> bool: """Check if data for the given _ID is valid. Parameters ---------- - _ID : Tuple[int, ...], optional + _ID: tuple[int, ...], optional The _ID to check validity for. Returns @@ -869,12 +885,15 @@ def is_valid(self, _ID: Tuple[int, ...] = ()) -> bool: except (KeyError, AttributeError): return False - def valid_index(self, _ID: Tuple[int, ...]) -> bool: + def valid_index( + self: DeepTrackNode, + _ID: tuple[int, ...], + ) -> bool: """Check if _ID is a valid index for this node’s data. Parameters ---------- - _ID : Tuple[int, ...] + _ID: tuple[int, ...] The _ID to validate. Returns @@ -886,18 +905,21 @@ def valid_index(self, _ID: Tuple[int, ...]) -> bool: return self.data.valid_index(_ID) - def invalidate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': + def invalidate( + self: DeepTrackNode, + _ID: tuple[int, ...] = (), + ) -> DeepTrackNode: """Mark this node’s data and all its children’s data as invalid. Parameters ---------- - _ID : Tuple[int, ...], optional + _ID: tuple[int, ...], optional The _ID to invalidate. Default is empty tuple, indicating potentially the full dataset. Returns ------- - self : DeepTrackNode + self: DeepTrackNode Returns the current node for chaining. Note @@ -914,17 +936,20 @@ def invalidate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': return self - def validate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': + def validate( + self: DeepTrackNode, + _ID: tuple[int, ...] = (), + ) -> DeepTrackNode: """Mark this node’s data as valid. Parameters ---------- - _ID : Tuple[int, ...], optional + _ID: tuple[int, ...], optional The _ID to validate. Default is empty tuple. Returns ------- - self : DeepTrackNode + self: DeepTrackNode """ @@ -932,7 +957,7 @@ def validate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': return self - def update(self) -> 'DeepTrackNode': + def update(self: DeepTrackNode) -> DeepTrackNode: """Reset data in all children. This method resets `data` for all children of each dependency, @@ -941,7 +966,7 @@ def update(self) -> 'DeepTrackNode': Returns ------- - self : DeepTrackNode + self: DeepTrackNode Returns the current node for chaining. """ @@ -957,7 +982,11 @@ def update(self) -> 'DeepTrackNode': return self - def set_value(self, value, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': + def set_value( + self: DeepTrackNode, + value: Any, + _ID: tuple[int, ...] = (), + ) -> DeepTrackNode: """Set a value for this node’s data at _ID. If the value is different from the currently stored one (or if it is @@ -965,14 +994,14 @@ def set_value(self, value, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': Parameters ---------- - value : Any + value: Any The value to store. - _ID : Tuple[int, ...], optional + _ID: tuple[int, ...], optional The _ID at which to store the value. Returns ------- - self : DeepTrackNode + self: DeepTrackNode Returns the current node for chaining. """ @@ -988,12 +1017,15 @@ def set_value(self, value, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': return self - def previous(self, _ID: Tuple[int, ...] = ()) -> Any: + def previous( + self: DeepTrackNode, + _ID: tuple[int, ...] = (), + ) -> Any: """Retrieve the previously stored value at _ID without recomputing. Parameters ---------- - _ID : Tuple[int, ...], optional + _ID: tuple[int, ...], optional The _ID for which to retrieve the previous value. Returns @@ -1010,14 +1042,14 @@ def previous(self, _ID: Tuple[int, ...] = ()) -> Any: return [] def recurse_children( - self, - memory: Optional[Set['DeepTrackNode']] = None, - ) -> Set['DeepTrackNode']: + self: DeepTrackNode, + memory: set[DeepTrackNode] | None = None, + ) -> set[DeepTrackNode]: """Return all children of this node. Parameters ---------- - memory : set, optional + memory: set, optional Memory set to track visited nodes (not used directly here). Returns @@ -1030,14 +1062,14 @@ def recurse_children( return self._all_children def old_recurse_children( - self, - memory: Optional[List['DeepTrackNode']] = None, - ) -> Iterator['DeepTrackNode']: + self: DeepTrackNode, + memory: list[DeepTrackNode] | None = None, + ) -> Iterator[DeepTrackNode]: """Legacy recursive method for traversing children. Parameters ---------- - memory : list, optional + memory: list, optional A list to remember visited nodes, ensuring that each node is yielded only once. @@ -1071,14 +1103,14 @@ def old_recurse_children( yield from child.recurse_children(memory=memory) def recurse_dependencies( - self, - memory: Optional[List['DeepTrackNode']] = None, + self: DeepTrackNode, + memory: list[DeepTrackNode] | None = None, ) -> Iterator['DeepTrackNode']: """Yield all dependencies of this node, ensuring each is visited once. Parameters ---------- - memory : list, optional + memory: list, optional A list of visited nodes to avoid repeated visits or infinite loops. Yields @@ -1106,7 +1138,7 @@ def recurse_dependencies( for dependency in self.dependencies: yield from dependency.recurse_dependencies(memory=memory) - def get_citations(self) -> Set[str]: + def get_citations(self: DeepTrackNode) -> set[str]: """Get citations from this node and all its dependencies. It gathers citations from this node and all nodes that it depends on. @@ -1134,7 +1166,10 @@ def get_citations(self) -> Set[str]: return citations - def __call__(self, _ID: Tuple[int, ...] = ()) -> Any: + def __call__( + self: DeepTrackNode, + _ID: tuple[int, ...] = (), + ) -> Any: """Evaluate this node at _ID. If the data at `_ID` is valid, it returns the stored value. Otherwise, @@ -1142,7 +1177,7 @@ def __call__(self, _ID: Tuple[int, ...] = ()) -> Any: Parameters ---------- - _ID : Tuple[int, ...], optional + _ID: tuple[int, ...], optional The _ID at which to evaluate the node’s action. Returns @@ -1169,12 +1204,15 @@ def __call__(self, _ID: Tuple[int, ...] = ()) -> Any: return self.current_value(_ID) - def current_value(self, _ID: Tuple[int, ...] = ()) -> Any: + def current_value( + self: DeepTrackNode, + _ID: tuple[int, ...] = (), + ) -> Any: """Retrieve the currently stored value at _ID. Parameters ---------- - _ID : Tuple[int, ...], optional + _ID: tuple[int, ...], optional The _ID at which to retrieve the current value. Returns @@ -1186,7 +1224,7 @@ def current_value(self, _ID: Tuple[int, ...] = ()) -> Any: return self.data[_ID].current_value() - def __hash__(self) -> int: + def __hash__(self: DeepTrackNode) -> int: """Return a unique hash for this node. Uses the node’s `id` to ensure uniqueness. @@ -1195,12 +1233,15 @@ def __hash__(self) -> int: return id(self) - def __getitem__(self, idx: Any) -> 'DeepTrackNode': + def __getitem__( + self: DeepTrackNode, + idx: Any, + ) -> DeepTrackNode: """Allow indexing into the node’s computed data. Parameters ---------- - idx : Any + idx: Any The index applied to the result of evaluating this node. Returns @@ -1230,7 +1271,10 @@ def __getitem__(self, idx: Any) -> 'DeepTrackNode': # and `other`. The operators are applied lazily and will be computed only # when the resulting node is evaluated. - def __add__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __add__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Add node to another node or value. Creates a new `DeepTrackNode` representing the addition of the values @@ -1238,7 +1282,7 @@ def __add__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to add. Returns @@ -1250,7 +1294,10 @@ def __add__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__add__, self, other) - def __radd__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __radd__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Add other value to node (right-hand). Creates a new `DeepTrackNode` representing the addition of another @@ -1258,7 +1305,7 @@ def __radd__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to add. Returns @@ -1270,7 +1317,10 @@ def __radd__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__add__, other, self) - def __sub__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __sub__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Subtract another node or value from node. Creates a new `DeepTrackNode` representing the subtraction of the @@ -1279,7 +1329,7 @@ def __sub__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to subtract. Returns @@ -1292,7 +1342,10 @@ def __sub__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__sub__, self, other) - def __rsub__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __rsub__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Subtract node from other value (right-hand). Creates a new `DeepTrackNode` representing the subtraction of the value @@ -1300,7 +1353,7 @@ def __rsub__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to subtract from. Returns @@ -1313,7 +1366,10 @@ def __rsub__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__sub__, other, self) - def __mul__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __mul__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Multiply node by another node or value. Creates a new `DeepTrackNode` representing the multiplication of the @@ -1322,7 +1378,7 @@ def __mul__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to multiply by. Returns @@ -1335,7 +1391,10 @@ def __mul__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__mul__, self, other) - def __rmul__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __rmul__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Multiply other value by node (right-hand). Creates a new `DeepTrackNode` representing the multiplication of @@ -1344,7 +1403,7 @@ def __rmul__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to multiply. Returns @@ -1356,9 +1415,9 @@ def __rmul__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__mul__, other, self) def __truediv__( - self, - other: Union['DeepTrackNode', Any], - ) -> 'DeepTrackNode': + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Divide node by another node or value. Creates a new `DeepTrackNode` representing the division of the value @@ -1366,7 +1425,7 @@ def __truediv__( Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to divide by. Returns @@ -1379,9 +1438,9 @@ def __truediv__( return _create_node_with_operator(operator.__truediv__, self, other) def __rtruediv__( - self, - other: Union['DeepTrackNode', Any], - ) -> 'DeepTrackNode': + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Divide other value by node (right-hand). Creates a new `DeepTrackNode` representing the division of another @@ -1389,7 +1448,7 @@ def __rtruediv__( Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to divide. Returns @@ -1402,9 +1461,9 @@ def __rtruediv__( return _create_node_with_operator(operator.__truediv__, other, self) def __floordiv__( - self, - other: Union['DeepTrackNode', Any], - ) -> 'DeepTrackNode': + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Perform floor division of node by another node or value. Creates a new `DeepTrackNode` representing the floor division of the @@ -1413,7 +1472,7 @@ def __floordiv__( Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to divide by. Returns @@ -1427,9 +1486,9 @@ def __floordiv__( return _create_node_with_operator(operator.__floordiv__, self, other) def __rfloordiv__( - self, - other: Union['DeepTrackNode', Any], - ) -> 'DeepTrackNode': + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Perform floor division of other value by node (right-hand). Creates a new `DeepTrackNode` representing the floor division of @@ -1438,7 +1497,7 @@ def __rfloordiv__( Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to divide. Returns @@ -1451,7 +1510,10 @@ def __rfloordiv__( return _create_node_with_operator(operator.__floordiv__, other, self) - def __lt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __lt__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Check if node is less than another node or value. Creates a new `DeepTrackNode` representing the comparison of this node @@ -1459,7 +1521,7 @@ def __lt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to compare with. Returns @@ -1472,7 +1534,10 @@ def __lt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__lt__, self, other) - def __rlt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __rlt__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Check if other value is less than node (right-hand). Creates a new `DeepTrackNode` representing the comparison of another @@ -1480,7 +1545,7 @@ def __rlt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to compare. Returns @@ -1493,7 +1558,10 @@ def __rlt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__lt__, other, self) - def __gt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __gt__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Check if node is greater than another node or value. Creates a new `DeepTrackNode` representing the comparison of this node @@ -1501,7 +1569,7 @@ def __gt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to compare with. Returns @@ -1514,7 +1582,10 @@ def __gt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__gt__, self, other) - def __rgt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __rgt__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Check if other value is greater than node (right-hand). Creates a new `DeepTrackNode` representing the comparison of another @@ -1522,7 +1593,7 @@ def __rgt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to compare. Returns @@ -1535,7 +1606,10 @@ def __rgt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__gt__, other, self) - def __le__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __le__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Check if node is less than or equal to another node or value. Creates a new `DeepTrackNode` representing the comparison of this node @@ -1543,7 +1617,7 @@ def __le__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to compare with. Returns @@ -1556,7 +1630,10 @@ def __le__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__le__, self, other) - def __rle__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __rle__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Check if other value is less than or equal to node (right-hand). Creates a new `DeepTrackNode` representing the comparison of another @@ -1564,7 +1641,7 @@ def __rle__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to compare. Returns @@ -1577,7 +1654,10 @@ def __rle__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__le__, other, self) - def __ge__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __ge__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Check if node is greater than or equal to another node or value. Creates a new `DeepTrackNode` representing the comparison of this node @@ -1586,7 +1666,7 @@ def __ge__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The node or value to compare with. Returns @@ -1599,7 +1679,10 @@ def __ge__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__ge__, self, other) - def __rge__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + def __rge__( + self: DeepTrackNode, + other: DeepTrackNode | Any, + ) -> DeepTrackNode: """Check if other value is greater than or equal to node (right-hand). Creates a new `DeepTrackNode` representing the comparison of another @@ -1608,7 +1691,7 @@ def __rge__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': Parameters ---------- - other : DeepTrackNode or Any + other: DeepTrackNode or Any The value or node to compare. Returns @@ -1622,7 +1705,10 @@ def __rge__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': return _create_node_with_operator(operator.__ge__, other, self) -def _equivalent(a: Any, b: Any) -> bool: +def _equivalent( + a: Any, + b: Any, +) -> bool: """Check if two objects are equivalent. This internal helper function provides a basic implementation to determine @@ -1634,9 +1720,9 @@ def _equivalent(a: Any, b: Any) -> bool: Parameters ---------- - a : Any + a: Any The first object to compare. - b : Any + b: Any The second object to compare. Returns @@ -1662,7 +1748,7 @@ def _create_node_with_operator( op: Callable, a: Any, b: Any, -) -> 'DeepTrackNode': +) -> DeepTrackNode: """Create a new computation node using a given operator and operands. This internal helper function constructs a `DeepTrackNode` obtained from @@ -1679,11 +1765,11 @@ def _create_node_with_operator( Parameters ---------- - op : Callable + op: Callable The operator function. - a : Any + a: Any First operand. If not a `DeepTrackNode`, it will be wrapped in one. - b : Any + b: Any Second operand. If not a `DeepTrackNode`, it will be wrapped in one. Returns From 740f962b82e214e709a6dd78914f75f50913b3ab Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 15:12:17 +0200 Subject: [PATCH 09/54] Update core.py --- deeptrack/backend/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index c58b03052..13f254f9b 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -164,8 +164,8 @@ class DeepTrackDataObject: def __init__(self: DeepTrackDataObject): """Initialize the container without data. - It sets the `data` and `valid` attributes are set to their default - values `None` and `False`. + It sets the `data` and `valid` attributes to their default values + `None` and `False`. """ @@ -242,7 +242,7 @@ class DeepTrackDataDict: The length of the _IDs currently stored. Set when the first entry is created. If `None`, no entries have been created yet, and any _ID length is valid. - dict: dict[tuple[int, ...], DeepTrackDataObject] + dict: dict[tuple[int, ...], DeepTrackDataObject] or {} A dictionary mapping tuples of integers (_IDs) to `DeepTrackDataObject` instances. @@ -351,7 +351,7 @@ class DeepTrackDataDict: """ - keylength: int + keylength: int | None dict: dict[tuple[int, ...], DeepTrackDataObject] def __init__(self: DeepTrackDataDict): From 8c28553ee14e26fb1f99c34e703b84af83d6a418 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 15:29:25 +0200 Subject: [PATCH 10/54] Update core.py --- deeptrack/backend/core.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 13f254f9b..9359eb50e 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -1718,6 +1718,10 @@ def _equivalent( - If both `a` and `b` are empty lists, they are considered equivalent. Additional cases can be implemented as needed to refine this behavior. + For immutable built-in types like empty tuples, integers, and `None`, Python + may reuse the same object in memory. Thus, `a is b` may return True even if + the objects are created separately. + Parameters ---------- a: Any @@ -1730,6 +1734,26 @@ def _equivalent( bool `True` if the objects are equivalent, `False` otherwise. + Examples + -------- + >>> from deeptrack.backend.core import _equivalent + + >>> _equivalent([], []) + True + + >>> a = [1, 2] + >>> _equivalent(a, a) + True + + >>> _equivalent([1], [1]) + False + + >>> _equivalent([], ()) + False + + >>> _equivalent(None, None) + True + """ # If a and b are the same object, return True. From 7c152bfb700ff486af868d4eb3f91a9c0322a9c1 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 15:29:28 +0200 Subject: [PATCH 11/54] Update test_core.py --- deeptrack/tests/backend/test_core.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/deeptrack/tests/backend/test_core.py b/deeptrack/tests/backend/test_core.py index d357913a9..21c408f1a 100644 --- a/deeptrack/tests/backend/test_core.py +++ b/deeptrack/tests/backend/test_core.py @@ -329,6 +329,33 @@ def test_DeepTrackNode_dependency_graph_with_ids(self): # 24 self.assertEqual(C_0_1_2, 24) + def test__equivalent(self): + # Identity check (same object) + a = [1, 2, 3] + self.assertTrue(core._equivalent(a, a)) + + # Both are empty lists (but not the same object) + self.assertTrue(core._equivalent([], [])) + a, b = [], [] + self.assertTrue(core._equivalent(a, b)) + + # Non-empty lists (not same object, not empty) + self.assertFalse(core._equivalent([1], [1])) + + # Empty list and None + self.assertFalse(core._equivalent([], None)) + + # Different types + self.assertFalse(core._equivalent(1, "1")) + + # Non-empty lists (same content, not same object) + a = [1] + b = [1] + self.assertFalse(core._equivalent(a, b)) + + # One empty list, one non-list empty container + self.assertFalse(core._equivalent([], ())) + if __name__ == "__main__": unittest.main() From aded0382adbdf97217f4b886390a08c6cba22445 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 17:00:38 +0200 Subject: [PATCH 12/54] Update core.py --- deeptrack/backend/core.py | 93 ++++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 9359eb50e..6901dd36a 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -645,7 +645,7 @@ class DeepTrackNode: __hash__() -> int Returns a unique hash for this node. __getitem__(idx: Any) -> DeepTrackNode - Creates a new node that indexes into this node’s computed data. + Creates a new node that indexes into this node's computed data. Example ------- @@ -700,12 +700,12 @@ class DeepTrackNode: @property def action(self: DeepTrackNode) -> Callable[..., Any]: - """Callable: The function that computes this node’s value. + """Callable: The function that computes this node's value. - When accessed, returns the current action. This is often a function or - lambda-function that takes `_ID` as an optional parameter if + When accessed, it returns the current action. This is often a function + or lambda-function that takes `_ID` as an optional parameter if `_accepts_ID` is `True`. - + """ return self._action @@ -713,21 +713,21 @@ def action(self: DeepTrackNode) -> Callable[..., Any]: @action.setter def action( self: DeepTrackNode, - value: Callable[..., Any], + _action: Callable[..., Any], ) -> None: - """Set the action used to compute this node’s value. + """Set the action used to compute this node's value. Parameters ---------- - value: Callable[..., Any] - A function or lambda to be used for computing the node’s value. If - the function’s signature includes `_ID`, this node will pass `_ID` - when calling `action`. + _action: Callable[..., Any] + A function or lambda-function used for computing the node's value. + If the function's signature includes `_ID`, this node will pass + `_ID` when calling `action`. """ - self._action = value - self._accepts_ID = "_ID" in get_kwarg_names(value) + self._action = _action + self._accepts_ID = "_ID" in get_kwarg_names(_action) def __init__( self: DeepTrackNode, @@ -739,10 +739,10 @@ def __init__( Parameters ---------- action: Callable or Any, optional - Action to compute this node’s value. If not provided, uses a no-op + Action to compute this node's value. If not provided, uses a no-op action (lambda: None). - **kwargs: dict + **kwargs: dict[str, Any] Additional arguments for subclasses or extended functionality. """ @@ -750,7 +750,7 @@ def __init__( self.data = DeepTrackDataDict() self.children = WeakSet() self.dependencies = WeakSet() - self._action = lambda: None # Default no-op action. + self._action = lambda: None # Default no-op action # If action is provided, set it. # If it's callable, use it directly; @@ -789,7 +789,7 @@ def add_child( Returns ------- self: DeepTrackNode - Returns the current node for chaining. + Return the current node for chaining. """ @@ -801,7 +801,7 @@ def add_child( children = child._all_children.copy() children.add(child) - # Merge all these children into this node’s subtree. + # Merge all these children into this node's subtree. self._all_children = self._all_children.union(children) for parent in self.recurse_dependencies(): parent._all_children = parent._all_children.union(children) @@ -818,12 +818,12 @@ def add_dependency( ---------- parent: DeepTrackNode The parent node that this node depends on. If `parent` changes, - this node’s data may become invalid. + this node's data may become invalid. Returns ------- self: DeepTrackNode - Returns the current node for chaining. + Return the current node for chaining. """ @@ -845,13 +845,13 @@ def store( data: Any The data to be stored. _ID: tuple[int, ...], optional - The index for this data. Default is the empty tuple (), indicating - a root-level entry. + The index for this data. If the _ID does not exist, it creates it. + Default is the empty tuple (), indicating a root-level entry. Returns ------- self: DeepTrackNode - Returns the current node for chaining. + Return the current node for chaining. """ @@ -889,7 +889,7 @@ def valid_index( self: DeepTrackNode, _ID: tuple[int, ...], ) -> bool: - """Check if _ID is a valid index for this node’s data. + """Check if _ID is a valid index for this node's data. Parameters ---------- @@ -909,7 +909,7 @@ def invalidate( self: DeepTrackNode, _ID: tuple[int, ...] = (), ) -> DeepTrackNode: - """Mark this node’s data and all its children’s data as invalid. + """Mark this node's data and all its children's data as invalid. Parameters ---------- @@ -920,7 +920,7 @@ def invalidate( Returns ------- self: DeepTrackNode - Returns the current node for chaining. + Return the current node for chaining. Note ---- @@ -940,7 +940,7 @@ def validate( self: DeepTrackNode, _ID: tuple[int, ...] = (), ) -> DeepTrackNode: - """Mark this node’s data as valid. + """Mark this node's data as valid. Parameters ---------- @@ -967,12 +967,12 @@ def update(self: DeepTrackNode) -> DeepTrackNode: Returns ------- self: DeepTrackNode - Returns the current node for chaining. + Return the current node for chaining. """ - # Pre-instantiate memory for optimization used to avoid repeated - # processing of the same nodes. + # Pre-instantiate memory for optimization, + # used to avoid repeated processing of the same nodes. child_memory = [] # For each dependency, reset data in all of its children. @@ -987,7 +987,7 @@ def set_value( value: Any, _ID: tuple[int, ...] = (), ) -> DeepTrackNode: - """Set a value for this node’s data at _ID. + """Set a value for this node's data at _ID. If the value is different from the currently stored one (or if it is invalid), it will invalidate the old data before storing the new one. @@ -1002,14 +1002,14 @@ def set_value( Returns ------- self: DeepTrackNode - Returns the current node for chaining. + Return the current node for chaining. """ # Check if current value is equivalent. If not, invalidate and store # the new value. If set to same value, no need to invalidate. if not ( - self.is_valid(_ID=_ID) + self.is_valid(_ID=_ID) and _equivalent(value, self.data[_ID].current_value()) ): self.invalidate(_ID=_ID) @@ -1038,8 +1038,8 @@ def previous( if self.data.valid_index(_ID): return self.data[_ID].current_value() - else: - return [] + + return [] # If `_ID` is not a valid index def recurse_children( self: DeepTrackNode, @@ -1050,12 +1050,14 @@ def recurse_children( Parameters ---------- memory: set, optional - Memory set to track visited nodes (not used directly here). + Set of nodes that have already been visited (not used directly + here). Returns ------- set All nodes in the subtree rooted at this node, including itself. + """ # Simply return `_all_children` since it's maintained incrementally. @@ -1081,7 +1083,7 @@ def old_recurse_children( Notes ----- This method is kept for backward compatibility or debugging purposes. - + """ # On first call, instantiate memory. @@ -1105,7 +1107,7 @@ def old_recurse_children( def recurse_dependencies( self: DeepTrackNode, memory: list[DeepTrackNode] | None = None, - ) -> Iterator['DeepTrackNode']: + ) -> Iterator[DeepTrackNode]: """Yield all dependencies of this node, ensuring each is visited once. Parameters @@ -1146,7 +1148,7 @@ def get_citations(self: DeepTrackNode) -> set[str]: Returns ------- - Set[str] + set[str] Set of all citations relevant to this node and its dependency tree. """ @@ -1178,13 +1180,13 @@ def __call__( Parameters ---------- _ID: tuple[int, ...], optional - The _ID at which to evaluate the node’s action. + The _ID at which to evaluate the node's action. Returns ------- Any The computed or retrieved data for the given `_ID`. - + """ if self.is_valid(_ID): @@ -1227,7 +1229,7 @@ def current_value( def __hash__(self: DeepTrackNode) -> int: """Return a unique hash for this node. - Uses the node’s `id` to ensure uniqueness. + Uses the node's `id` to ensure uniqueness. """ @@ -1237,7 +1239,7 @@ def __getitem__( self: DeepTrackNode, idx: Any, ) -> DeepTrackNode: - """Allow indexing into the node’s computed data. + """Allow indexing into the node's computed data. Parameters ---------- @@ -1254,9 +1256,10 @@ def __getitem__( ----- This effectively creates a node that corresponds to `self(...)[idx]`, allowing you to select parts of the computed data dynamically. + """ - # Create a new node whose action indexes into this node’s result. + # Create a new node whose action indexes into this node's result. node = DeepTrackNode(lambda _ID=None: self(_ID=_ID)[idx]) self.add_child(node) @@ -1773,7 +1776,7 @@ def _create_node_with_operator( a: Any, b: Any, ) -> DeepTrackNode: - """Create a new computation node using a given operator and operands. + """Create a new computation node using the given operator and operands. This internal helper function constructs a `DeepTrackNode` obtained from the application of the specified operator to two operands. If the operands From 2fe8dd17c2e058b4a25aa48be871ad1f2b8c87c4 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 17:00:40 +0200 Subject: [PATCH 13/54] Update test_core.py --- deeptrack/tests/backend/test_core.py | 35 ++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/deeptrack/tests/backend/test_core.py b/deeptrack/tests/backend/test_core.py index 21c408f1a..b1b1292f9 100644 --- a/deeptrack/tests/backend/test_core.py +++ b/deeptrack/tests/backend/test_core.py @@ -356,6 +356,41 @@ def test__equivalent(self): # One empty list, one non-list empty container self.assertFalse(core._equivalent([], ())) + def test__create_node_with_operator(self): + import operator + + # Test with integers (should be wrapped automatically) + node = core._create_node_with_operator(operator.add, 2, 3) + self.assertIsInstance(node, core.DeepTrackNode) + self.assertEqual(node(), 5) + + # Test with DeepTrackNode operands (addition) + a = core.DeepTrackNode(lambda: 10) + b = core.DeepTrackNode(lambda: 7) + node2 = core._create_node_with_operator(operator.sub, a, b) + self.assertIsInstance(node2, core.DeepTrackNode) + self.assertEqual(node2(), 3) + + # node2 should be a child of both a and b + self.assertIn(node2, a.children) + self.assertIn(node2, b.children) + + # a and b should both be dependencies of node2 + self.assertIn(a, node2.dependencies) + self.assertIn(b, node2.dependencies) + + # Test with one DeepTrackNode, one plain value (multiplication) + node3 = core._create_node_with_operator(operator.mul, a, 2) + self.assertEqual(node3(), 20) + self.assertIsInstance(node3, core.DeepTrackNode) + self.assertIn(node3, a.children) + + # Ensure wrapping of right operand + node4 = core._create_node_with_operator(operator.mul, 3, b) + self.assertEqual(node4(), 21) + self.assertIsInstance(node4, core.DeepTrackNode) + self.assertIn(node4, b.children) + if __name__ == "__main__": unittest.main() From 203709db7416759f42ce9d027f6e071d1806d6ce Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 17:14:22 +0200 Subject: [PATCH 14/54] Update core.py --- deeptrack/backend/core.py | 91 +++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 41 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 6901dd36a..eade1ade5 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -579,9 +579,16 @@ class DeepTrackNode: """Object corresponding to a node in a computation graph. `DeepTrackNode` represents a node within a DeepTrack2 computation graph. - In the DeepTrack2 computation graph, each node can store data and compute - new values based on its dependencies. The value of a node is computed by - calling its `action` method. + Each node can store data and compute new values based on its dependencies. + The value of a node is computed by calling its `action` method. + + Parameters + ---------- + action: Callable or Any, optional + Action to compute this node's value. If not provided, uses a no-op + action (lambda: None). + **kwargs: dict[str, Any] + Additional arguments for subclasses or extended functionality. Attributes ---------- @@ -589,8 +596,11 @@ class DeepTrackNode: Dictionary-like object for storing data, indexed by tuples of integers. children: WeakSet[DeepTrackNode] Nodes that depend on this node (its children, grandchildren, etc.). + This is a weakref.WeakSet, so references are weak and do not prevent + garbage collection of nodes that are no longer used. dependencies: WeakSet[DeepTrackNode] Nodes on which this node depends (its parents, grandparents, etc.). + This is a weakref.WeakSet, for efficient memory management. _action: Callable The function or lambda-function to compute the node value. _accepts_ID: bool @@ -602,49 +612,49 @@ class DeepTrackNode: Methods ------- - action: Property - Gets or sets the computation function for the node. - add_child(child: DeepTrackNode) -> DeepTrackNode - Adds a child node that depends on this node. - Also adds the dependency in the child node on this node. - add_dependency(parent: DeepTrackNode) -> DeepTrackNode - Adds a dependency, making this node depend on the parent node. - It also sets this node as a child of the parent node. - store(data: Any, _ID: tuple[int, ...] = ()) -> DeepTrackNode - Stores computed data for the given `_ID`. - is_valid(_ID: tuple[int, ...] = ()) -> bool - Checks if the data for the given `_ID` is valid. - valid_index(_ID: tuple[int, ...]) -> bool - Checks if the given `_ID` is valid for this node. - invalidate(_ID: tuple[int, ...] = ()) -> DeepTrackNode - Invalidates the data for the given `_ID` and all child nodes. - validate(_ID: tuple[int, ...] = ()) -> DeepTrackNode - Validates the data for the given `_ID`, marking it as up-to-date, but + `action: property` + Get or set the computation function for the node (stored as `_action`). + `add_child(child: DeepTrackNode) -> DeepTrackNode` + Add a child node that depends on this node. + Also add the dependency on this node in the child node. + `add_dependency(parent: DeepTrackNode) -> DeepTrackNode` + Add a dependency, making this node depend on the parent node. + Also set this node as a child of the parent node. + `store(data: Any, _ID: tuple[int, ...] = ()) -> DeepTrackNode` + Store computed data for the given `_ID`. + `is_valid(_ID: tuple[int, ...] = ()) -> bool` + Check whether the data for the given `_ID` is valid. + `valid_index(_ID: tuple[int, ...]) -> bool` + Check whether the given `_ID` is valid for this node. + `invalidate(_ID: tuple[int, ...] = ()) -> DeepTrackNode` + Invalidate the data for the given `_ID` and all child nodes. + `validate(_ID: tuple[int, ...] = ()) -> DeepTrackNode` + Validate the data for the given `_ID`, marking it as up-to-date, but not its children. - update() -> DeepTrackNode - Resets the data. - set_value(value: Any, _ID: tuple[int, ...] = ()) -> DeepTrackNode - Sets a value for the given `_ID`. If the new value differs from the + `update() -> DeepTrackNode` + Reset the data. + `set_value(value: Any, _ID: tuple[int, ...] = ()) -> DeepTrackNode` + Set a value for the given `_ID`. If the new value differs from the current value, the node is invalidated to ensure dependencies are recomputed. - previous(_ID: tuple[int, ...] = ()) -> Any - Returns the previously stored value for the given `_ID` without + `previous(_ID: tuple[int, ...] = ()) -> Any` + Return the previously stored value for the given `_ID` without recomputing it. - recurse_children(memory: set[DeepTrackNode] | None = None) -> set[DeepTrackNode] - Returns all child nodes in the dependency tree rooted at this node. - recurse_dependencies(memory: list[DeepTrackNode] | None = None) -> Iterator[DeepTrackNode] - Yields all nodes that this node depends on, traversing dependencies. - get_citations() -> set[str] - Returns a set of citations for this node and its dependencies. - __call__(_ID: tuple[int, ...] = ()) -> Any - Evaluates the node's computation for the given `_ID`, recomputing if + `recurse_children(memory: set[DeepTrackNode] | None = None) -> set[DeepTrackNode]` + Return all child nodes in the dependency tree rooted at this node. + `recurse_dependencies(memory: list[DeepTrackNode] | None = None) -> Iterator[DeepTrackNode]` + Yield all nodes that this node depends on, traversing dependencies. + `get_citations() -> set[str]` + Return a set of citations for this node and its dependencies. + `__call__(_ID: tuple[int, ...] = ()) -> Any` + Evaluate the node's computation for the given `_ID`, recomputing if necessary. - current_value(_ID: tuple[int, ...] = ()) -> Any - Returns the currently stored value for the given `_ID` without + `current_value(_ID: tuple[int, ...] = ()) -> Any` + Return the currently stored value for the given `_ID` without recomputation. - __hash__() -> int - Returns a unique hash for this node. - __getitem__(idx: Any) -> DeepTrackNode + `__hash__() -> int` + Return a unique hash for this node. + `__getitem__(idx: Any) -> DeepTrackNode` Creates a new node that indexes into this node's computed data. Example @@ -741,7 +751,6 @@ def __init__( action: Callable or Any, optional Action to compute this node's value. If not provided, uses a no-op action (lambda: None). - **kwargs: dict[str, Any] Additional arguments for subclasses or extended functionality. From fa3ba347a43e23ab4ec5dffa1cbeabd69a8d95ed Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 17:19:13 +0200 Subject: [PATCH 15/54] Update core.py --- deeptrack/backend/core.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index eade1ade5..31c629edc 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -576,12 +576,18 @@ def __contains__( class DeepTrackNode: - """Object corresponding to a node in a computation graph. + """Node in a DeepTrack2 computation graph, supporting operator overloading. `DeepTrackNode` represents a node within a DeepTrack2 computation graph. Each node can store data and compute new values based on its dependencies. The value of a node is computed by calling its `action` method. + `DeepTrackNode` supports operator overloading, enabling intuitive + construction of computation graphs using standard Python operators. + For example, nodes can be added, multiplied, subtracted, or compared + directly (e.g., `node1 + node2`, `node1 * 3`, `node1 > node2`), and the + resulting node will represent the composed operation. + Parameters ---------- action: Callable or Any, optional @@ -657,6 +663,26 @@ class DeepTrackNode: `__getitem__(idx: Any) -> DeepTrackNode` Creates a new node that indexes into this node's computed data. + Supported Operators + ------------------- + DeepTrackNode supports the following Python operators: + + Arithmetic: + + Addition (__add__, __radd__) + - Subtraction (__sub__, __rsub__) + * Multiplication (__mul__, __rmul__) + / True division (__truediv__, __rtruediv__) + // Floor division (__floordiv__, __rfloordiv__) + + Comparison: + < Less than (__lt__, __rlt__) + <= Less than or equal (__le__, __rle__) + > Greater than (__gt__, __rgt__) + >= Greater than or equal (__ge__, __rge__) + + Each operation returns a new DeepTrackNode representing the + result of the corresponding operation in the computation graph. + Example ------- Create two `DeepTrackNode` objects: From bf820501b6ede88e7ed7ed99d70eae79371a833e Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 17:55:20 +0200 Subject: [PATCH 16/54] Update core.py --- deeptrack/backend/core.py | 112 ++++++++++++++++++++++++++++++++------ 1 file changed, 94 insertions(+), 18 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 31c629edc..a4d7dac46 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -685,41 +685,117 @@ class DeepTrackNode: Example ------- - Create two `DeepTrackNode` objects: + >>> from deeptrack.backend.core import DeepTrackNode + + Create two `DeepTrackNode` objects, one as a parent and one as a child: >>> parent = DeepTrackNode(action=lambda: 10) >>> child = DeepTrackNode(action=lambda _ID=None: parent(_ID) * 2) - - First, establish the dependency between `parent` and `child`: - >>> parent.add_child(child) - Store values in the parent node for specific _IDs: + Store and retrieve data for specific _IDs: >>> parent.store(15, _ID=(0,)) >>> parent.store(20, _ID=(1,)) + >>> parent.current_value((0,)) + 15 + >>> parent.current_value((1,)) + 20 + + Compute and retrieve the value for the child node: - Compute the values for the child node based on these parent values: + >>> child(_ID=(0,)) + 30 + >>> child(_ID=(1,)) + 40 - >>> child_value_0 = child(_ID=(0,)) - >>> child_value_1 = child(_ID=(1,)) - >>> print(child_value_0, child_value_1) - 30 40 + Validation and invalidation: - Invalidate the parent data for a specific _ID: + >>> parent.is_valid((0,)) + True + >>> child.is_valid((0,)) + True >>> parent.invalidate((0,)) - >>> print(parent.is_valid((0,))) + >>> parent.is_valid((0,)) False - >>> print(child.is_valid((0,))) + >>> child.is_valid((0,)) + False + + >>> parent.validate((0,)) + >>> parent.is_valid((0,)) + True + >>> child.is_valid((0,)) + False + + Setting a value and automatic invalidation: + + >>> parent.previous((0,)) + 15 + >>> child((1,)) # Computes and stores the value in child + >>> child.previous((0,)) + 30 + + >>> parent.set_value(42, _ID=(0,)) + >>> parent.current_value((0,)) + 42 + >>> child((0,)) # Recomputes and stores the value in child + >>> child.current_value((0,)) + 84 + + Resetting all data in the dependency tree (recomputation required): + + >>> parent.update() + + Dependency graph traversal (children and dependencies): + + >>> all_children = parent.recurse_children() + >>> all_dependencies = list(child.recurse_dependencies()) + + Operator overloading—arithmetic and comparison: + + >>> node_a = DeepTrackNode(lambda: 5) + >>> node_b = DeepTrackNode(lambda: 3) + + >>> sum_node = node_a + node_b + >>> sum_node() + 8 + + >>> diff_node = node_a - node_b + >>> diff_node() + 2 + + >>> prod_node = node_a * 2 + >>> prod_node() + 10 + + >>> div_node = node_a / node_b + >>> div_node() + 1.666... + + >>> floordiv_node = node_a // node_b + >>> floordiv_node() + 1 + + >>> lt_node = node_a < node_b + >>> lt_node() False - Update the parent value and recompute the child value: + >>> ge_node = node_a >= node_b + >>> ge_node() + True + + Indexing into computed data: + + >>> vector_node = DeepTrackNode(lambda: [10, 20, 30]) + >>> first_element = vector_node[0] + >>> first_element() + 10 + + Citations for a node and its dependencies: - >>> parent.store(25, _ID=(0,)) - >>> child_value_recomputed = child(_ID=(0,)) - >>> print(child_value_recomputed) - 50 + >>> parent.get_citations() # Set of citation strings + {...} """ From 8411100d10ff226aff82a09f0e14be9aa83d1a58 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 18:02:31 +0200 Subject: [PATCH 17/54] Update core.py --- deeptrack/backend/core.py | 140 ++++++++++++++++++++++++++------------ 1 file changed, 95 insertions(+), 45 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index a4d7dac46..22e9f0d68 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -1,72 +1,122 @@ """Core data structures for DeepTrack2. -This module provides the core DeepTrack2 classes to manage and process data. -In particular, it enables users to: +This module defines the foundational data structures used throughout +DeepTrack2 for constructing, managing, and evaluating computational graphs +with flexible data storage and dependency management. -- Construct flexible and efficient computational pipelines. -- Manage data and dependencies in a hierarchical structure. -- Perform lazy evaluations for performance optimization. +Key Features +------------ +- **Hierarchical Data Management** -Main Features -------------- -- **Data Management** - - `DeepTrackDataObject` and `DeepTrackDataDict` provide tools to store, - validate, and manage data with dependency tracking. They enable nested - data structures and flexible indexing for complex data hierarchies. + Provides validated, hierarchical data containers (`DeepTrackDataObject` + and `DeepTrackDataDict`) for storing data and managing complex, nested + data structures. Supports dependency tracking and flexible indexing. -- **Computational Graphs** - - `DeepTrackNode` forms the backbone of DeepTrack2 computation pipelines, - representing computation nodes in a computation graph. Nodes support lazy - evaluation, dependency tracking, and caching for improved computational - performance. They implement mathematical operators for easy composition - of computational graphs. +- **Computation Graphs with Lazy Evaluation** + + Implements the `DeepTrackNode` class, the core abstraction for nodes in + a computational graph. Supports lazy evaluation, caching, dependency + tracking, and operator overloading for intuitive composition of complex + computational pipelines. -- **Citations** +- **Citation Support** - Supports citing the relevant publication to ensure proper attribution - (e.g., `Midtvedt et al., 2021`). + Provides citation metadata to ensure proper academic attribution for work + built on DeepTrack2. Module Structure ------------------ -Data Containers: +---------------- +Classes: -- `DeepTrackDataObject`: Basic data container with validation status. +- `DeepTrackDataObject`: Basic container for data with validation status. - A basic container for data with validation status. - -- `DeepTrackDataDict`: Dictionary to store multiple data with validation. + Simple data container that stores data and tracks its validity + (valid/invalid). - A data container to store multiple data objects (`DeepTrackDataObject`) - indexed by unique access _IDs (consisting of tuples of integers), enabling - nested data storage. +- `DeepTrackDataDict`: Hierarchical dictionary for multiple data objects. -Computation Nodes: + Stores multiple `DeepTrackDataObject` instances indexed by tuples of + integers, enabling the creation of flexible, nested data hierarchies. -- `DeepTrackNode`: Node in a computation graph. - - Represents a node in a computation graph, capable of lazy evaluation, - caching, and dependency management. +- `DeepTrackNode`: Node in a computation graph with operator overloading. -Example -------- -Create two `DeepTrackNode` objects: + Represents a node in a computation graph, capable of storing and + computing values based on dependencies, with full support for lazy + evaluation, dependency tracking, and operator overloading. ->>> parent = DeepTrackNode() ->>> child = DeepTrackNode(lambda: 2 * parent()) ->>> parent.add_child(child) +Functions: + +- `_equivalent(a, b)` + + def _equivalent(a: Any, b: Any) -> bool + + Determines whether two objects should be considered equivalent, + according to DeepTrack2's internal rules (identity, empty lists, etc). + +- `_create_node_with_operator(op, a, b)` + + def _create_node_with_operator( + op: Callable, + a: Any, + b: Any, + ) -> DeepTrackNode + + Internal helper to create a new computation node by applying a + specified operator to two operands, establishing correct graph + relationships and supporting operator overloading. -Set the value of the parent: +Attributes: +- `CITATION_MIDTVEDT2021QUANTITATIVE`: str + + BibTeX citation for the original DeepTrack2 publication. + +Examples +-------- +>>> import deeptrack as dt + +Create a simple computational pipeline using DeepTrack2 nodes: + +>>> parent = dt.DeepTrackNode() +>>> child = dt.DeepTrackNode(lambda: 2 * parent()) +>>> parent.add_child(child) >>> parent.store(5) +>>> child() # Compute child +10 + +Operator overloading for computation nodes: -And print the value of the child: +>>> a = dt.DeepTrackNode(lambda: 3) +>>> b = dt.DeepTrackNode(lambda: 4) +>>> sum_node = a + b +>>> sum_node() +7 ->>> print(child()) # Output: 10 +Create and use a hierarchical data dictionary: + +>>> data_dict = dt.DeepTrackDataDict() +>>> data_dict.create_index((0, 1)) +>>> data_dict[(0, 1)].store("Example data") +>>> data_dict[(0, 1)].current_value() +'Example data' + +Validate and invalidate a data object: + +>>> data_obj = dt.DeepTrackDataObject() +>>> data_obj.is_valid() +False + +>>> data_obj.store(42) +>>> data_obj.is_valid() +True + +>>> data_obj.invalidate() +>>> data_obj.is_valid() +False """ + from __future__ import annotations import operator # Operator overloading for computation nodes. From 530ad755cf3382b7ffae55ceea0fee1c12a2521e Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 18:32:24 +0200 Subject: [PATCH 18/54] Update DTAT399F_backend._config.ipynb --- tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb b/tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb index 4dea5ead9..d38e7bd62 100644 --- a/tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb +++ b/tutorials/3-advanced-topics/DTAT399F_backend._config.ipynb @@ -22,7 +22,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This advanced tutorial introduces the backend._config module." + "This advanced tutorial introduces the `backend._config.py` module." ] }, { From 0113458341a36b8a269d5937eae4a32c55bc67e4 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 19:06:58 +0200 Subject: [PATCH 19/54] Update DTAT399A_backend.core.ipynb --- .../DTAT399A_backend.core.ipynb | 1133 ++++++++++++++--- 1 file changed, 989 insertions(+), 144 deletions(-) diff --git a/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb b/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb index 671679e6a..9c8174594 100644 --- a/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb +++ b/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb @@ -4,14 +4,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# deeptrack.backend.core\n", + "# DTAT399A. deeptrack.backend.core\n", "\n", "\"Open" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -22,259 +22,1104 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This advanced tutorial introduces the backend.core module." + "This advanced tutorial introduces the `backend.core.py` module." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. What is `core`?\n", + "## 1. What is `core.py`?\n", "\n", - "The `core` module provides fundamental utilities and functions to manage and process data on a low level.\n", + "The `core.py` module is DeepTrack2’s foundation for data management and computational graph construction.\n", "\n", - "In particular it provide tools to store, validate, and manage data and computational nodes with dependency tracking.\n" + "It provides the fundamental classes and abstractions that underpin all DeepTrack2 pipelines, enabling flexible, efficient, and traceable computation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 2. Basic Node Usage with Parent-Child Dependency" + "The key roles of `core.py` are:\n", + "\n", + "- *Data Object Abstractions:**\n", + " Defines simple and validated data containers (`DeepTrackDataObject` and `DeepTrackDataDict`) that store, index, and validate arbitrary data. These classes enable hierarchical and multidimensional organization of complex datasets.\n", + "\n", + "- **Computation Graph Nodes:**\n", + " Implements the `DeepTrackNode` class, which represents a node in a computational graph. Each node can compute, store, and cache values, and can express dependencies on other nodes—enabling the creation of highly flexible and efficient processing pipelines.\n", + "\n", + "- **Lazy Evaluation & Caching:**\n", + " Supports on-demand computation and result caching through lazy evaluation. Nodes only compute their value when required, and cache results for future use until dependencies are invalidated.\n", + "\n", + "- **Operator Overloading for Pipelines:**\n", + " Enables intuitive construction of complex computational graphs using standard Python arithmetic and comparison operators (e.g., +, *, <). This makes pipeline composition both expressive and readable.\n", + "\n", + "- **Dependency Tracking & Propagation:**\n", + " Tracks parent-child and dependency relationships among nodes, so that changes or invalidations automatically propagate through the graph—guaranteeing computational consistency.\n", + "\n", + "- **Citation and Provenance:**\n", + " Integrates citation metadata, ensuring proper academic attribution for work that builds upon DeepTrack2’s infrastructure." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Using Nodes with Parent-Child Dependencies\n", + "\n", + "In DeepTrack2, nodes represent computational units that can be flexibly linked into graphs by defining dependencies. This allows you to build modular, traceable, and efficient pipelines where changes automatically propagate through the graph.\n", + "\n", + "Below we show how to set up parent-child relationships between nodes, store and compute data, and propagate invalidation when the upstream data changes." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from deeptrack.backend.core import DeepTrackNode" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.1. Creating Parent and Child Nodes\n", + "\n", + "We create a parent node and a child node whose value is always twice that of its parent." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Create parent and child nodes\n", + "parent = DeepTrackNode(action=lambda: 10)\n", + "child = DeepTrackNode(action=lambda _ID=None: parent(_ID) * 2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2. Establishing Parent-Child Dependency\n", + "\n", + "We link the parent and child so that the child automatically tracks changes in the parent. In this way, the parent is updated or invalidated, this relationship ensures that the child is also kept up to date." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Establish parent-child dependency\n", + "parent.add_child(child)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.3. Storing Values and Computing Results\n", + "\n", + "Let’s assign different values to the parent for different data indices (`_ID`)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Store values in parent node associated to different _IDs\n", + "parent.store(15, _ID=(0,))\n", + "parent.store(20, _ID=(1,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4. Computing and Accessing Child Values\n", + "\n", + "The child node computes its value based on the current value of the parent for each index." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "30" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "child(_ID=(0,))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "40" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "child(_ID=(1,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**NOTE:** Calling `child(_ID=(0,))` computes the value if needed, and caches it.\n", + "On the other hand, calling `child.current_value((0,))` retrieves the currently cached value without recomputing.\n", + "Therefore, you can access the last computed value for a specific index using `.current_value(_ID)`.\n", + "If the value hasn’t yet been computed or stored, this will raise a `KeyError`." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "30" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Retrieve the cached value without recomputing\n", + "child.current_value((0,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**NOTE:** The `.previous(_ID)` method is similar to `.current_value(_ID)`, but will return an empty list if the index is not valid, instead of raising an error.\n", + "This is useful for checking if an index is valid without triggering a computation." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "30" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "child.previous((0,))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "child.previous((42, 43)) # (42, 43) is not a valid index" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.5. Validation and Invalidation\n", + "\n", + "When you invalidate the parent for a particular _ID, the child’s value for that _ID will also be marked as invalid (since it depends on the parent). This ensures that downstream computations are never out of sync." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Invalidate parent data for a given ID.\n", + "parent.invalidate((0,))\n", + "parent.is_valid((0,))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "child.is_valid((0,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.6. Updating and Recomputing Values\n", + "\n", + "After invalidation, if we update the parent and request the child’s value again, it will be recomputed as needed." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "50" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Update the parent value and recompute the child value\n", + "parent.store(25, _ID=(0,))\n", + "child((0,))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "parent.is_valid((0,))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "child.is_valid((0,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.7 Setting a Value and Automatic Invalidation\n", + "\n", + "You can force a value into a node’s storage with `.set_value(value, _ID)`. If the new value is different, dependent nodes will be invalidated." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "100" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "parent.set_value(100, _ID=(1,))\n", + "parent.current_value((1,))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "parent.is_valid((1,))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "child.is_valid((1,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Lazy Evaluation and Caching\n", + "\n", + "A powerful feature of DeepTrack2 nodes is lazy evaluation: the node’s value is only computed when it is needed, and the result is cached until the node (or its dependencies) is invalidated. This avoids redundant computations and ensures high efficiency, especially in large graphs.\n", + "\n", + "In this example, we’ll use a global counter to demonstrate when the node’s computation actually happens." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.1 Defining a Node with a Side Effect\n", + "\n", + "First, we define a calculation function that increments a global counter each time it is called. This allows us to see exactly how many times the node’s computation is performed." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# Create counter node with side effect\n", + "call_count = 0\n", + "def calculation():\n", + " global call_count\n", + " call_count += 1\n", + " return 10\n", + "\n", + "node = DeepTrackNode(calculation)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2 Demonstrating Lazy Evaluation\n", + "\n", + "Let’s see what happens when we call the node multiple times:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# First call computes the value (calls the function)\n", + "node()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "call_count" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Second call uses the cached value (no additional computation)\n", + "node()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "call_count" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.3 Invalidation Forces Recalculation\n", + "\n", + "If we invalidate the node, the cache is cleared and the next call will recompute the value:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Invalidate the node and call again (forces recomputation)\n", + "node.invalidate()\n", + "node()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "call_count" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Invalidate and call again\n", + "node.invalidate()\n", + "node()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "call_count" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Data Management with IDs\n", + "\n", + "In DeepTrack2, the `DeepTrackDataDict` class provides an efficient, validated way to manage multiple data objects, each indexed by a unique tuple of integers.\n", + "\n", + "This is especially useful for working with multidimensional datasets, or for mapping results to experiment or batch indices." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.1. Creating and Indexing Data Objects\n", + "\n", + "You can create entries with arbitrary integer index tuples, just like keys in a nested dictionary." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "from deeptrack.backend.core import DeepTrackDataDict\n", + "\n", + "data_dict = DeepTrackDataDict()\n", + "\n", + "# Create listings with unique indices.\n", + "data_dict.create_index((0, 0))\n", + "data_dict.create_index((0, 1))\n", + "data_dict.create_index((1, 0))\n", + "data_dict.create_index((1, 1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.2. Storing and Retrieving Data\n", + "\n", + "Each index corresponds to a `DeepTrackDataObject`, where you can store and retrieve data.\n", + "This is similar to using a multidimensional dictionary." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# Store some data for the indices.\n", + "data_dict[(0, 0)].store(\"Cat\")\n", + "data_dict[(0, 1)].store(\"Dog\")\n", + "data_dict[(1, 0)].store(\"Mouse\")\n", + "data_dict[(1, 1)].store(\"Bird\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.3. Accessing Data by ID\n", + "\n", + "You can access data by its full index, or get a dictionary of all entries with a common prefix." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 30, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "30 40\n", - "False\n", - "False\n", - "50\n" - ] + "data": { + "text/plain": [ + "'Cat'" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from deeptrack.backend.core import DeepTrackNode\n", - "\n", - "parent = DeepTrackNode(action=lambda: 10)\n", - "child = DeepTrackNode(action=lambda _ID=None: parent(_ID) * 2)\n", - "\n", - "# Establish parent-child dependency.\n", - "parent.add_child(child)\n", - "\n", - "# Store values.\n", - "parent.store(15, _ID=(0,))\n", - "parent.store(20, _ID=(1,))\n", - "\n", - "# Compute values based on parent values.\n", - "child_value_0 = child(_ID=(0,))\n", - "child_value_1 = child(_ID=(1,))\n", - "print(child_value_0, child_value_1)\n", - "\n", - "# Invalidate parent data for a given ID.\n", - "parent.invalidate((0,))\n", - "print(parent.is_valid((0,)))\n", - "\n", - "# Update the parent value and recompute the child value:\n", - "print(child.is_valid((0,)))\n", - "parent.store(25, _ID=(0,))\n", - "child_value_recomputed = child(_ID=(0,))\n", - "print(child_value_recomputed)" + "# Retrieve and print values for specific indices.\n", + "data_dict[(0, 0)].current_value()" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 31, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Bird'" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "## 3. Lazy evaluation and Caching\n", - "Here we add a function to a `DeepTrackNode` which retuns a constant value and updates a global counter variable when called." + "data_dict[(1, 1)].current_value()" ] }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10 1\n", - "10 2\n", - "10 3\n" + "{(0, 0): , (0, 1): }\n" ] } ], "source": [ - "# Create counter node with side effect\n", - "call_count = 0\n", - "def calculation():\n", - " global call_count\n", - " call_count += 1\n", - " return 10\n", + "# Retrieve all entries whose indices start with (0,)\n", + "print(data_dict[(0, )])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Operator Overloading and Pipeline Composition\n", "\n", - "node = DeepTrackNode(calculation)\n", + "A unique and powerful feature of `DeepTrackNode` is its support for operator overloading. This allows you to build complex computational pipelines by composing nodes using familiar arithmetic and comparison operators, making your code both expressive and readable.\n", "\n", - "# First call computes value.\n", - "print(node(), call_count) \n", + "Every operator creates a new node that, when called, evaluates its operands, applies the operator, and caches the result. Dependency relationships are automatically tracked, so invalidating an operand will invalidate any composed nodes as well." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.1. Combining Nodes with Arithmetic Operators\n", "\n", - "# Subsequent call uses cached value.\n", - "node.invalidate()\n", - "print(node(), call_count) \n", + "You can add, subtract, multiply, or divide nodes just like numbers. The result is always a new `DeepTrackNode` that represents the composed computation." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "from deeptrack.backend.core import DeepTrackNode\n", "\n", - "# Invalidate and call again.\n", - "node.invalidate()\n", - "print(node(), call_count) " + "a = DeepTrackNode(lambda: 5)\n", + "b = DeepTrackNode(lambda: 3)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sum_node = a + b\n", + "sum_node()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "diff_node = a - b\n", + "diff_node()" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prod_node = a * 2\n", + "prod_node()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.6666666666666667" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "div_node = a / b\n", + "div_node()" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "floordiv_node = a // b\n", + "floordiv_node()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Data Management with IDs\n", + "### 5.2. Chaining and Nesting Operators\n", "\n", - "Map IDs to stored `DeepTrackData` objects lika a dictionary." + "You can compose pipelines of arbitrary depth and complexity:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Cat\n", - "Bird\n", - "{(0, 0): , (0, 1): }\n" - ] + "data": { + "text/plain": [ + "4.0" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from deeptrack.backend.core import DeepTrackDataDict\n", - "\n", - "data_dict = DeepTrackDataDict()\n", - "\n", - "# Create listings with unique indices.\n", - "data_dict.create_index((0, 0))\n", - "data_dict.create_index((0, 1))\n", - "data_dict.create_index((1, 0))\n", - "data_dict.create_index((1, 1))\n", - "\n", - "# Store some data for the indices.\n", - "data_dict[(0, 0)].store(\"Cat\")\n", - "data_dict[(0, 1)].store(\"Dog\")\n", - "data_dict[(1, 0)].store(\"Mouse\")\n", - "data_dict[(1, 1)].store(\"Bird\")\n", - "\n", - "# Print the indices.\n", - "print(data_dict[(0, 0)].current_value())\n", - "print(data_dict[(1, 1)].current_value())\n", - "print(data_dict[(0, )])" + "complex_node = ((a + b) * 2) / (b + 1)\n", + "complex_node()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 5. Propagating operators\n", - "Nodes can also be used as simple handles for functions." + "### 5.3. Comparison Operators for Graphs\n", + "\n", + "Comparison operators also work on nodes, returning new nodes that compute boolean results:" ] }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 40, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "16\n", - "60\n" - ] + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "a = DeepTrackNode(lambda: 5 + 5)\n", - "b = DeepTrackNode(lambda: 3 + 3)\n", - "\n", - "sum_node = a + b\n", - "product_node = a * b\n", - "\n", - "print(sum_node())\n", - "print(product_node())" + "lt_node = a < b\n", + "lt_node()" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ge_node = a >= b\n", + "ge_node()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 6. Validation control\n", - "Validate or invalidate nodes manually to enable/disable storing data." + "### 5.4. Mixing Nodes and Constants\n", + "\n", + "You can mix DeepTrackNode instances and regular numbers:" ] }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 42, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "100\n", - "True\n", - "100\n", - "False\n", - "42\n" - ] + "data": { + "text/plain": [ + "12" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "node = DeepTrackNode(lambda: 42)\n", - "node.store(100)\n", - "\n", - "print(node())\n", - "\n", - "# Validate.\n", - "node.validate()\n", - "print(node.is_valid())\n", - "print(node()) \n", - "\n", - "# Invalidate.\n", - "node.invalidate()\n", - "print(node.is_valid())\n", - "print(node())" + "sum_with_constant = a + 7\n", + "sum_with_constant()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mult_with_constant = 3 * b\n", + "mult_with_constant()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 7. Get Citations\n", - "The `DeepTrackNode` class can also be used to obtain citations." + "## 6. Getting Citations\n", + "\n", + "The `DeepTrackNode` class can also be used to obtain the relevant citations." ] }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'\\n@article{Midtvet2021DeepTrack,\\n author = {Midtvedt,Benjamin and \\n Helgadottir,Saga and \\n Argun,Aykut and \\n Pineda,Jesús and \\n Midtvedt,Daniel and \\n Volpe,Giovanni},\\n title = {Quantitative digital microscopy with deep learning},\\n journal = {Applied Physics Reviews},\\n volume = {8},\\n number = {1},\\n pages = {011310},\\n year = {2021},\\n doi = {10.1063/5.0034891}\\n}\\n'}" + "{'\\n@article{Midtvet2021Quantitative,\\n author = {Midtvedt, Benjamin and Helgadottir, Saga and Argun, Aykut and \\n Pineda, Jesús and Midtvedt, Daniel and Volpe, Giovanni},\\n title = {Quantitative digital microscopy with deep learning},\\n journal = {Applied Physics Reviews},\\n volume = {8},\\n number = {1},\\n pages = {011310},\\n year = {2021},\\n doi = {10.1063/5.0034891}\\n}\\n'}" ] }, - "execution_count": 101, + "execution_count": 44, "metadata": {}, "output_type": "execute_result" } @@ -286,7 +1131,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "py_env_book", "language": "python", "name": "python3" }, @@ -300,7 +1145,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.15" } }, "nbformat": 4, From 1801ec9d3925d4e1eb3c0896d9337aadb99d248f Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 19:12:45 +0200 Subject: [PATCH 20/54] Update pint_definition.py --- deeptrack/backend/pint_definition.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/deeptrack/backend/pint_definition.py b/deeptrack/backend/pint_definition.py index a804a7dc9..af9903da8 100644 --- a/deeptrack/backend/pint_definition.py +++ b/deeptrack/backend/pint_definition.py @@ -56,6 +56,12 @@ """ +__all__ = [ + "pint_constants", + "pint_definitions", +] + + pint_constants = """ # Default Pint constants definition file # Based on the International System of Units From ed99efa0a7761615c13ad4a5309993f713c237ae Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 21:05:29 +0200 Subject: [PATCH 21/54] u --- deeptrack/backend/core.py | 8 ++-- deeptrack/tests/backend/test_core.py | 2 +- .../DTAT399A_backend.core.ipynb | 48 ------------------- 3 files changed, 4 insertions(+), 54 deletions(-) diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 22e9f0d68..38e883b57 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -693,9 +693,6 @@ class DeepTrackNode: Set a value for the given `_ID`. If the new value differs from the current value, the node is invalidated to ensure dependencies are recomputed. - `previous(_ID: tuple[int, ...] = ()) -> Any` - Return the previously stored value for the given `_ID` without - recomputing it. `recurse_children(memory: set[DeepTrackNode] | None = None) -> set[DeepTrackNode]` Return all child nodes in the dependency tree rooted at this node. `recurse_dependencies(memory: list[DeepTrackNode] | None = None) -> Iterator[DeepTrackNode]` @@ -780,10 +777,10 @@ class DeepTrackNode: Setting a value and automatic invalidation: - >>> parent.previous((0,)) + >>> parent.current_value((0,)) 15 >>> child((1,)) # Computes and stores the value in child - >>> child.previous((0,)) + >>> child.current_value((0,)) 30 >>> parent.set_value(42, _ID=(0,)) @@ -1178,6 +1175,7 @@ def set_value( return self + # TODO: The previous() method should be moved into SequentialProperty def previous( self: DeepTrackNode, _ID: tuple[int, ...] = (), diff --git a/deeptrack/tests/backend/test_core.py b/deeptrack/tests/backend/test_core.py index b1b1292f9..b143f6063 100644 --- a/deeptrack/tests/backend/test_core.py +++ b/deeptrack/tests/backend/test_core.py @@ -208,7 +208,7 @@ def test_DeepTrackNode_single_id(self): # Retrieves the values stored in children and parents. for id, value in enumerate(range(10)): self.assertEqual(child(_ID=(id,)), value * 2) - self.assertEqual(parent.previous((id,)), value) + self.assertEqual(parent.current_value((id,)), value) def test_DeepTrackNode_nested_ids(self): # Test nested IDs for parent-child relationships. diff --git a/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb b/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb index 9c8174594..235199fc7 100644 --- a/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb +++ b/tutorials/3-advanced-topics/DTAT399A_backend.core.ipynb @@ -242,54 +242,6 @@ "child.current_value((0,))" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**NOTE:** The `.previous(_ID)` method is similar to `.current_value(_ID)`, but will return an empty list if the index is not valid, instead of raising an error.\n", - "This is useful for checking if an index is valid without triggering a computation." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "30" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "child.previous((0,))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "child.previous((42, 43)) # (42, 43) is not a valid index" - ] - }, { "cell_type": "markdown", "metadata": {}, From 22264a215c7872582c3d458e9d42815a204d6ad3 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Thu, 29 May 2025 23:49:50 +0200 Subject: [PATCH 22/54] Update utils.py --- deeptrack/utils.py | 67 ++++++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 23 deletions(-) diff --git a/deeptrack/utils.py b/deeptrack/utils.py index 74804b2b9..2519ec64d 100644 --- a/deeptrack/utils.py +++ b/deeptrack/utils.py @@ -15,11 +15,25 @@ """ +from __future__ import annotations + import inspect -from typing import Any, Callable, List +from typing import Any, Callable + + +__all__ = [ + "hasmethod", + "as_list", + "get_kwarg_names", + "kwarg_has_default", + "safe_call", +] -def hasmethod(obj: Any, method_name: str) -> bool: +def hasmethod( + obj: Any, + method_name: str, +) -> bool: """Check if an object has a callable method named `method_name`. Returns `True` if the object has a field named `method_name` that is @@ -27,9 +41,9 @@ def hasmethod(obj: Any, method_name: str) -> bool: Parameters ---------- - obj : Any + obj: Any The object to inspect. - method_name : str + method_name: str The name of the method to look for. Returns @@ -52,7 +66,7 @@ def as_list(obj: any) -> list: Parameters ---------- - obj : Any + obj: Any The object to be converted or wrapped in a list. Returns @@ -68,7 +82,7 @@ def as_list(obj: any) -> list: return [obj] -def get_kwarg_names(function: Callable) -> List[str]: +def get_kwarg_names(function: Callable) -> list[str]: """Retrieve the names of the keyword arguments accepted by a function. It retrieves the names of the keyword arguments accepted by `function` as a @@ -76,7 +90,7 @@ def get_kwarg_names(function: Callable) -> List[str]: Parameters ---------- - function : Callable + function: Callable The function whose keyword argument names are to be retrieved. Returns @@ -97,14 +111,17 @@ def get_kwarg_names(function: Callable) -> List[str]: return argspec.args or [] -def kwarg_has_default(function: Callable, argument: str) -> bool: +def kwarg_has_default( + function: Callable, + argument: str, +) -> bool: """Check if a specific argument of a function has a default value. Parameters ---------- - function : Callable + function: Callable The function to inspect. - argument : str + argument: str Name of the argument to check. Returns @@ -124,7 +141,11 @@ def kwarg_has_default(function: Callable, argument: str) -> bool: return len(args) - args.index(argument) <= len(defaults) -def safe_call(function, positional_args=[], **kwargs) -> Any: +def safe_call( + function: Callable[..., Any], + positional_args: list[Any] | None = None, + **kwargs: Any, +) -> Any: """Calls a function with valid arguments from a dictionary of arguments. It filters `kwargs` to include only arguments accepted by the function, @@ -133,26 +154,26 @@ def safe_call(function, positional_args=[], **kwargs) -> Any: Parameters ---------- - function : Callable + function: Callable[..., Any] The function to call. - positional_args : list, optional - List of positional arguments to pass to the function. - kwargs : dict + positional_args: list[Any] | None, optional + List of positional arguments to pass to the function. Defaults to None. + **kwargs: dict[str, Any] Dictionary of keyword arguments to filter and pass. - + Returns ------- Any The result of calling the function with the filtered arguments. - + """ - keys = get_kwarg_names(function) + if positional_args is None: + positional_args = [] # Filter kwargs to include only keys present in the function's signature. - input_arguments = {} - for key in keys: - if key in kwargs: - input_arguments[key] = kwargs[key] + input_arguments = { + key: kwargs[key] for key in get_kwarg_names(function) if key in kwargs + } - return function(*positional_args, **input_arguments) \ No newline at end of file + return function(*positional_args, **input_arguments) From ececad767b1c0c73e2436703a8ba7f3201eed60c Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Fri, 30 May 2025 00:03:24 +0200 Subject: [PATCH 23/54] Update utils.py --- deeptrack/utils.py | 95 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 17 deletions(-) diff --git a/deeptrack/utils.py b/deeptrack/utils.py index 2519ec64d..658f76335 100644 --- a/deeptrack/utils.py +++ b/deeptrack/utils.py @@ -1,17 +1,75 @@ -"""Utility functions. +"""Utility functions for argument handling and signature inspection. -This module defines utility functions that enhance code readability, -streamline common operations, and ensure type and argument consistency. +This module provides utility functions to enhance code readability, +streamline common operations, and ensure type and argument consistency +when working with functions, methods, and callables in Python. + +Key Features +------------ +- **Method Detection** + + Check if an object has a callable method with a given name. + +- **List Conversion** + + Ensure that any input is represented as a list. + +- **Signature Inspection** + + Retrieve the names of arguments a function accepts, and check for + default values. + +- **Safe Function Calling** + + Call a function by passing only arguments accepted by its signature. Module Structure ---------------- Functions: -- `hasmethod`: Checks if an object has a callable method named `method_name`. -- `as_list`: Ensures that the input is a list. -- `get_kwarg_names`: Retrieves keyword argument names accepted by a function. -- `kwarg_has_default`: Checks if a function argument has a default value. -- `safe_call`: Calls a function, passing only valid arguments. +- `hasmethod(obj, method_name)` + + def hasmethod( + obj: Any, + method_name: str, + ) -> bool + + Check if an object has a callable method named `method_name`. + +- `as_list(obj)` + + def as_list(obj: Any) -> list[Any] + + Ensure that the input is a list, wrapping if necessary. + +- `get_kwarg_names(function)` + + def get_kwarg_names(function: Callable[..., Any]) -> list[str] + + Retrieve the names of the keyword arguments accepted by a function. + +- `kwarg_has_default(function, argument)` + + def kwarg_has_default( + function: Callable[..., Any], + argument: str, + ) -> bool + + Check if a specific argument of a function has a default value. + +- `safe_call(function, positional_args=None, **kwargs)` + + def safe_call( + function: Callable[..., Any], + positional_args: list[Any] | None = None, + **kwargs: Any, + ) -> Any + + Call a function, passing only valid arguments from a dictionary. + +Examples +-------- +TODO """ @@ -58,11 +116,11 @@ def hasmethod( and callable(getattr(obj, method_name, None))) -def as_list(obj: any) -> list: +def as_list(obj: Any) -> list[Any]: """Ensure that the input is a list. - Converts the input to a list if it is iterable; otherwise, it wraps it in a - list. + Converts the input to a list if it is iterable and not a string or bytes; + otherwise, it wraps it in a list. Parameters ---------- @@ -71,18 +129,21 @@ def as_list(obj: any) -> list: Returns ------- - list + list[Any] The input object as a list. """ + if isinstance(obj, (str, bytes)): + return [obj] + try: return list(obj) except TypeError: return [obj] -def get_kwarg_names(function: Callable) -> list[str]: +def get_kwarg_names(function: Callable[..., Any]) -> list[str]: """Retrieve the names of the keyword arguments accepted by a function. It retrieves the names of the keyword arguments accepted by `function` as a @@ -90,7 +151,7 @@ def get_kwarg_names(function: Callable) -> list[str]: Parameters ---------- - function: Callable + function: Callable[..., Any] The function whose keyword argument names are to be retrieved. Returns @@ -112,14 +173,14 @@ def get_kwarg_names(function: Callable) -> list[str]: def kwarg_has_default( - function: Callable, + function: Callable[..., Any], argument: str, ) -> bool: """Check if a specific argument of a function has a default value. Parameters ---------- - function: Callable + function: Callable[..., Any] The function to inspect. argument: str Name of the argument to check. @@ -130,7 +191,7 @@ def kwarg_has_default( True if the specified argument has a default value. """ - + args = get_kwarg_names(function) if argument not in args: From d9d9d007977bb81ee6871f8ecf6cc36c530b7803 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Fri, 30 May 2025 00:08:38 +0200 Subject: [PATCH 24/54] Update utils.py --- deeptrack/utils.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/deeptrack/utils.py b/deeptrack/utils.py index 658f76335..2da20c586 100644 --- a/deeptrack/utils.py +++ b/deeptrack/utils.py @@ -94,7 +94,7 @@ def hasmethod( ) -> bool: """Check if an object has a callable method named `method_name`. - Returns `True` if the object has a field named `method_name` that is + It returns `True` if the object has a field named `method_name` that is callable. Otherwise, returns `False`. Parameters @@ -110,6 +110,10 @@ def hasmethod( True if the object has an attribute named `method_name` that is callable. + Examples + -------- + TODO + """ return (hasattr(obj, method_name) @@ -119,8 +123,8 @@ def hasmethod( def as_list(obj: Any) -> list[Any]: """Ensure that the input is a list. - Converts the input to a list if it is iterable and not a string or bytes; - otherwise, it wraps it in a list. + It converts the input to a list if it is iterable and not a string or + bytes; otherwise, it wraps it in a list. Parameters ---------- @@ -132,6 +136,10 @@ def as_list(obj: Any) -> list[Any]: list[Any] The input object as a list. + Examples + -------- + TODO + """ if isinstance(obj, (str, bytes)): @@ -156,9 +164,13 @@ def get_kwarg_names(function: Callable[..., Any]) -> list[str]: Returns ------- - List[str] + list[str] A list of names of keyword arguments the function accepts. + Examples + -------- + TODO + """ try: @@ -190,6 +202,10 @@ def kwarg_has_default( bool True if the specified argument has a default value. + Examples + -------- + TODO + """ args = get_kwarg_names(function) @@ -227,6 +243,10 @@ def safe_call( Any The result of calling the function with the filtered arguments. + Examples + -------- + TODO + """ if positional_args is None: From 28d29f0a08f4274280e24397fe7896e174b64efe Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Fri, 30 May 2025 00:15:13 +0200 Subject: [PATCH 25/54] Update utils.py --- deeptrack/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deeptrack/utils.py b/deeptrack/utils.py index 2da20c586..0c5fb819c 100644 --- a/deeptrack/utils.py +++ b/deeptrack/utils.py @@ -126,6 +126,11 @@ def as_list(obj: Any) -> list[Any]: It converts the input to a list if it is iterable and not a string or bytes; otherwise, it wraps it in a list. + Note: If `obj` is a PyTorch Tensor, this function will return a list of its + elements along the first dimension (e.g., for a 2D tensor, the result + will be a list of 1D tensors). If you want to wrap the entire tensor in a + list, use `[obj]` explicitly. + Parameters ---------- obj: Any From 2b92865cda0fb1ae51749b2ba5ae3b93fce682d2 Mon Sep 17 00:00:00 2001 From: Giovanni Volpe Date: Fri, 30 May 2025 00:34:47 +0200 Subject: [PATCH 26/54] Update test_utils.py --- deeptrack/tests/test_utils.py | 170 ++++++++++++++++++++-------------- 1 file changed, 102 insertions(+), 68 deletions(-) diff --git a/deeptrack/tests/test_utils.py b/deeptrack/tests/test_utils.py index d476c4d5d..ba2deff9b 100644 --- a/deeptrack/tests/test_utils.py +++ b/deeptrack/tests/test_utils.py @@ -8,98 +8,132 @@ import unittest +import deeptrack as dt +import numpy as np + from deeptrack import utils +class DummyClass: + def method(self): pass + def __len__(self): return 42 + + class TestUtils(unittest.TestCase): def test_hasmethod(self): self.assertTrue(utils.hasmethod(utils, "hasmethod")) - self.assertFalse( - utils.hasmethod(utils, "this_is_definetely_not_a_method_of_utils") - ) + self.assertFalse(utils.hasmethod(utils, "not_a_method")) + self.assertTrue(utils.hasmethod(DummyClass, "method")) + self.assertFalse(utils.hasmethod(DummyClass, "not_real")) + self.assertTrue(utils.hasmethod(DummyClass(), "method")) + self.assertTrue(utils.hasmethod(DummyClass(), "__len__")) + self.assertFalse(utils.hasmethod(123, "foo")) # int has no foo - def test_as_list(self): - obj = 1 - self.assertEqual(utils.as_list(obj), [obj]) + # Built-in edge cases + self.assertTrue(utils.hasmethod([], "append")) + self.assertFalse(utils.hasmethod([], "not_a_real_method")) - list_obj = [1, 2, 3] - self.assertEqual(utils.as_list(list_obj), list_obj) + def test_as_list(self): + # Scalars + self.assertEqual(utils.as_list(1), [1]) + self.assertEqual(utils.as_list(None), [None]) + self.assertEqual(utils.as_list(3.14), [3.14]) + + # Containers + self.assertEqual(utils.as_list([1, 2]), [1, 2]) + self.assertEqual(utils.as_list((1, 2)), [1, 2]) + self.assertEqual(sorted(utils.as_list({1, 2})), [1, 2]) + + # Generator + gen = (i for i in range(2)) + self.assertEqual(utils.as_list(gen), [0, 1]) + + # Strings and bytes + self.assertEqual(utils.as_list("abc"), ["abc"]) + self.assertEqual(utils.as_list(b"123"), [b"123"]) + + # Numpy array + arr = np.array([1, 2, 3]) + result = utils.as_list(arr) + self.assertTrue(isinstance(result, list)) + self.assertTrue(all(isinstance(x, (int, np.generic)) for x in result)) + + if dt.TORCH_AVAILABLE: + import torch + + tensor = torch.tensor([[1, 2], [3, 4]]) + result = utils.as_list(tensor) + + # By default, this will be [tensor([1, 2]), tensor([3, 4])] + self.assertEqual(len(result), 2) + self.assertTrue(all(isinstance(x, torch.Tensor) for x in result)) def test_get_kwarg_names(self): - def func1(): - pass + def f1(): pass + self.assertEqual(utils.get_kwarg_names(f1), []) - self.assertEqual(utils.get_kwarg_names(func1), []) + def f2(a): pass + self.assertEqual(utils.get_kwarg_names(f2), ["a"]) - def func2(key1): - pass + def f3(a, b=1): pass + self.assertEqual(utils.get_kwarg_names(f3), ["a", "b"]) - self.assertEqual(utils.get_kwarg_names(func2), ["key1"]) + def f4(a, *args, b=2): pass + self.assertEqual(utils.get_kwarg_names(f4), ["b"]) - def func3(key1, key2=2): - pass + def f5(*args, b, c=2): pass + self.assertEqual(utils.get_kwarg_names(f5), ["b", "c"]) - self.assertEqual(utils.get_kwarg_names(func3), ["key1", "key2"]) + def f6(a, b, *args): pass + self.assertEqual(utils.get_kwarg_names(f6), []) - def func4(key1, *argv, key2=2): - pass + def f7(a, b=1, c=3, **kwargs): pass + self.assertEqual(utils.get_kwarg_names(f7), ["a", "b", "c"]) - self.assertEqual(utils.get_kwarg_names(func4), ["key2"]) + # Built-in function (should not raise) + self.assertIsInstance(utils.get_kwarg_names(len), list) - def func5(*argv, key1, key2=2): - pass + # Lambda + l = lambda a, b=2: a + b + self.assertEqual(utils.get_kwarg_names(l), ["a", "b"]) - self.assertEqual(utils.get_kwarg_names(func5), ["key1", "key2"]) + # Method + self.assertIn("self", utils.get_kwarg_names(DummyClass.method)) - def func6(key1, key2, key3, *argv): - pass + def test_kwarg_has_default(self): + def f1(a, b=2): pass + self.assertFalse(utils.kwarg_has_default(f1, "a")) + self.assertTrue(utils.kwarg_has_default(f1, "b")) - self.assertEqual(utils.get_kwarg_names(func6), []) - - def func7(key1, key2=1, key3=3, **kwargs): - pass - - self.assertEqual(utils.get_kwarg_names(func7), ["key1", "key2", "key3"]) + # Not in function + self.assertFalse(utils.kwarg_has_default(f1, "c")) def test_safe_call(self): - - arguments = { - "key1": None, - "key2": False, - "key_not_in_function": True, - "key_not_in_function_2": True, - } - - def func1(): - pass - - utils.safe_call(func1, **arguments) - - def func2(key1): - pass - - utils.safe_call(func2, **arguments) - - def func3(key1, key2=2): - pass - - utils.safe_call(func3, **arguments) - - def func4(key1, *argv, key2=2): - pass - - self.assertRaises(TypeError, lambda: utils.safe_call(func4, **arguments)) - - def func5(*argv, key1, key2=2): - pass - - utils.safe_call(func5, **arguments) - - def func6(key1, key2=1, key3=3, **kwargs): - pass - - utils.safe_call(func6, **arguments) + def f(a, b=2, c=3): return a + b + c + # All args present + self.assertEqual(utils.safe_call(f, positional_args=[1], b=2, c=3), 6) + # Only some kwargs present + self.assertEqual(utils.safe_call(f, positional_args=[1], b=4), 8) + # No kwargs + self.assertEqual(utils.safe_call(f, positional_args=[1]), 6) + # Extra kwargs are ignored + self.assertEqual(utils.safe_call(f, positional_args=[1], b=5, x=10), 9) + # Only kwargs + self.assertEqual(utils.safe_call(f, a=1, b=2, c=3), 6) + + # Should ignore kwargs not in function signature + def g(a): return a + self.assertEqual(utils.safe_call(g, a=42, extrakw=1), 42) + + # Missing required arg should raise error + def f(a): return a + with self.assertRaises(TypeError): + utils.safe_call(f) + + def g(a, *, b): return a + b + with self.assertRaises(TypeError): + utils.safe_call(g, a=1) # Missing b if __name__ == "__main__": From 04d965bf724d136b3b3742bb6dbb24ee27604e26 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Fri, 30 May 2025 12:54:54 +0200 Subject: [PATCH 27/54] Implemented torch.rand() --- deeptrack/backend/array_api_compat_ext/torch/random.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index 2ec864f17..5cb277a47 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -19,8 +19,12 @@ ] -def rand(*args: int) -> torch.Tensor: - return torch.rand(*args) +def rand( + *args: int, + dtype: torch.dtype=torch.float32, + device: torch.device | str = torch.device("cpu"), +) -> torch.Tensor: + return torch.rand(*args, dtype=dtype, device=device) def random(size: tuple[int, ...] | None = None) -> torch.Tensor: From ab09fb42defd5675c10b96506d8ee13fbdaab274 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Fri, 30 May 2025 13:19:37 +0200 Subject: [PATCH 28/54] Unit test for random.rand() --- deeptrack/tests/backend/test_random.py | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 deeptrack/tests/backend/test_random.py diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py new file mode 100644 index 000000000..164cc78ee --- /dev/null +++ b/deeptrack/tests/backend/test_random.py @@ -0,0 +1,27 @@ +import unittest + +import torch + +from deeptrack.backend.array_api_compat_ext.torch import random + + +class TestRandom(unittest.TestCase): + def test_rand(self): + shapes = [(2, ) , (3, 4) ] + dtypes = [torch.float32, torch.float64] + devices = [torch.device("cpu"), "cpu"] + for i in range(len(shapes)): + shape = shapes[i] + dtype = dtypes[i] + device = devices[i] + + torch.manual_seed(1) + expected = torch.rand(*shape, dtype=dtype, device=device) + + torch.manual_seed(1) + generated = rand(*shape, dtype=dtype, device=device) + + self.assertEqual(generated.shape, expected.shape) + self.assertEqual(generated.dtype, expected.dtype) + self.assertEqual(generated.device, expected.device) + self.assertTrue(torch.equal(generated, expected)) From f8017dfa8941736f50bb8cc535273931f0b6b203 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Fri, 30 May 2025 13:20:59 +0200 Subject: [PATCH 29/54] typo --- deeptrack/tests/backend/test_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index 164cc78ee..bbc064acf 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -19,7 +19,7 @@ def test_rand(self): expected = torch.rand(*shape, dtype=dtype, device=device) torch.manual_seed(1) - generated = rand(*shape, dtype=dtype, device=device) + generated = random.rand(*shape, dtype=dtype, device=device) self.assertEqual(generated.shape, expected.shape) self.assertEqual(generated.dtype, expected.dtype) From 6494c440335de087b42e9bd66fb359bf09ff599c Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Fri, 30 May 2025 13:26:30 +0200 Subject: [PATCH 30/54] added torch to requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 9d7eb35e6..228b04918 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ numpy matplotlib scipy scikit-image +torch more_itertools pint pandas From b24081edcd1ca31394eea49d8cf63a5fcecc3fc3 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Fri, 30 May 2025 14:14:39 +0200 Subject: [PATCH 31/54] test_random: mean test and numpy comparison --- deeptrack/tests/backend/test_random.py | 37 +++++++++++--------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index bbc064acf..ca9d9e364 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -1,27 +1,22 @@ import unittest -import torch - +import numpy as np from deeptrack.backend.array_api_compat_ext.torch import random -class TestRandom(unittest.TestCase): - def test_rand(self): - shapes = [(2, ) , (3, 4) ] - dtypes = [torch.float32, torch.float64] - devices = [torch.device("cpu"), "cpu"] - for i in range(len(shapes)): - shape = shapes[i] - dtype = dtypes[i] - device = devices[i] - - torch.manual_seed(1) - expected = torch.rand(*shape, dtype=dtype, device=device) +class TestRandomNumpy(unittest.TestCase): + def test_rand(self): + shapes = [(2, ), (3, 4)] + dtypes = [torch.float32, torch.float64] + devices = [torch.device("cpu"), "cpu"] + + for shape, dtype, device in zip(shapes, dtypes, devices): + + expected = np.random.rand(*shape) + generated = rand(*shape, dtype=dtype, device=device) + self.assertEqual(generated.shape, expected.shape) + self.assertEqual(generated.dtype, dtype) - torch.manual_seed(1) - generated = random.rand(*shape, dtype=dtype, device=device) - - self.assertEqual(generated.shape, expected.shape) - self.assertEqual(generated.dtype, expected.dtype) - self.assertEqual(generated.device, expected.device) - self.assertTrue(torch.equal(generated, expected)) + a = rand(100, dtype=torch.float32, device="cpu") + b = np.random.rand(100) + self.assertAlmostEqual(a.mean(), np.mean(b), delta = 1) From c99c22eacbeb1fbedd1926c65c54c4493a1f8be6 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Fri, 30 May 2025 14:16:55 +0200 Subject: [PATCH 32/54] Import torch --- deeptrack/tests/backend/test_random.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index ca9d9e364..5b73e393d 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -1,10 +1,11 @@ import unittest import numpy as np +import torch from deeptrack.backend.array_api_compat_ext.torch import random -class TestRandomNumpy(unittest.TestCase): +class TestRandom(unittest.TestCase): def test_rand(self): shapes = [(2, ), (3, 4)] dtypes = [torch.float32, torch.float64] From 39ef3b4f1c45c622fcbe65b2a4a7b10c2b51d0dd Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Fri, 30 May 2025 14:18:44 +0200 Subject: [PATCH 33/54] Update test_random.py --- deeptrack/tests/backend/test_random.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index 5b73e393d..9a98d5ae8 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -14,10 +14,10 @@ def test_rand(self): for shape, dtype, device in zip(shapes, dtypes, devices): expected = np.random.rand(*shape) - generated = rand(*shape, dtype=dtype, device=device) + generated = random.rand(*shape, dtype=dtype, device=device) self.assertEqual(generated.shape, expected.shape) self.assertEqual(generated.dtype, dtype) - a = rand(100, dtype=torch.float32, device="cpu") + a = random.rand(100, dtype=torch.float32, device="cpu") b = np.random.rand(100) self.assertAlmostEqual(a.mean(), np.mean(b), delta = 1) From 0193668dac577072e04d971e7a76f3d467f5825f Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Fri, 30 May 2025 14:38:59 +0200 Subject: [PATCH 34/54] implemented random.beta() --- deeptrack/backend/array_api_compat_ext/torch/random.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index 5cb277a47..90baac4c7 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch +import numpy as np __all__ = [ "rand", @@ -42,9 +43,11 @@ def randn(*args: int) -> torch.Tensor: def beta( a: float, b: float, - size: tuple[int, ...] | None = None, + size: int | tuple[int, ...] = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - raise NotImplementedError("the beta distribution is not implemented in torch") + return torch.tensor(np.random.beta(a, b, size), dtype=dtype, device=device) def binomial( From f8c14b559af5b73becc8e0842794c430fc2e8a79 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 5 Jun 2025 15:44:59 +0200 Subject: [PATCH 35/54] Update test_random.py --- deeptrack/tests/backend/test_random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index 9a98d5ae8..3469bcbc3 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -20,4 +20,4 @@ def test_rand(self): a = random.rand(100, dtype=torch.float32, device="cpu") b = np.random.rand(100) - self.assertAlmostEqual(a.mean(), np.mean(b), delta = 1) + self.assertAlmostEqual(a.mean(), np.mean(b), delta=1) # Use a different rand From 6c1eca9e4336e6684a62f5350814dd5acb50d895 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Sat, 28 Jun 2025 18:25:12 +0200 Subject: [PATCH 36/54] binomial added --- .../backend/array_api_compat_ext/torch/random.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index 90baac4c7..c0a0218a8 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -36,8 +36,12 @@ def random_sample(size: tuple[int, ...] | None = None) -> torch.Tensor: return torch.rand(*size) if size else torch.rand() -def randn(*args: int) -> torch.Tensor: - return torch.randn(*args) +def randn( + *args: int, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), +) -> torch.Tensor: + return torch.randn(*args, dtype=dtype, device=device) def beta( @@ -54,8 +58,11 @@ def binomial( n: int, p: float, size: tuple[int, ...] | None = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - return torch.bernoulli(torch.full(size, p)) + #return torch.bernoulli(torch.full(size, p)) + return torch.tensor(np.random.binomial(n, p, size), dtype=dtype, device=device) def choice( From 2bab7ccf098b2d6150495bc3d938063cf978afa0 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:06:58 +0200 Subject: [PATCH 37/54] Several functions implemented. --- .../array_api_compat_ext/torch/random.py | 90 +++++++++++++++---- 1 file changed, 72 insertions(+), 18 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index c0a0218a8..69fc7983e 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -1,3 +1,23 @@ +"""xp compatibility module for Numpy functions + +This module contains wrapper functions for various numpy.random functions +that return torch tensors when and can accept optional `dtype` and`device` arguments. + + +Examples +-------- +Sample the `beta` distribution: + +>>> from torch import cuda, float16 + +>>> if cuda.is_available(): +... print(beta(1, 2, dtype=torch.float16, device="cuda")) + +tensor(0.3315, device='cuda:0', dtype=torch.float16) + + +""" + from __future__ import annotations import torch @@ -28,12 +48,26 @@ def rand( return torch.rand(*args, dtype=dtype, device=device) -def random(size: tuple[int, ...] | None = None) -> torch.Tensor: - return torch.rand(*size) if size else torch.rand() +def random( + size: tuple[int, ...] | None = None, + dtype: torch.dtype=torch.float32, + device: torch.device | str = torch.device("cpu"), +) -> torch.Tensor: + return ( + torch.rand(*size, dtype=dtype, device=device) + if size else torch.rand(dtype=dtype, device=device) + ) -def random_sample(size: tuple[int, ...] | None = None) -> torch.Tensor: - return torch.rand(*size) if size else torch.rand() +def random_sample( + size: tuple[int, ...] | None = None, + dtype: torch.dtype=torch.float32, + device: torch.device | str = torch.device("cpu"), +) -> torch.Tensor: + return ( + torch.rand(*size, dtype=dtype, device=device) + if size else torch.rand(dtype=dtype, device=device) + ) def randn( @@ -51,7 +85,9 @@ def beta( dtype: torch.dtype = torch.float32, device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - return torch.tensor(np.random.beta(a, b, size), dtype=dtype, device=device) + return ( + torch.tensor(np.random.beta(a, b, size), dtype=dtype, device=device) + ) def binomial( @@ -61,62 +97,80 @@ def binomial( dtype: torch.dtype = torch.float32, device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - #return torch.bernoulli(torch.full(size, p)) - return torch.tensor(np.random.binomial(n, p, size), dtype=dtype, device=device) - + return ( + torch.tensor(np.random.binomial(n, p, size), dtype=dtype, device=device) + ) def choice( - a: torch.Tensor, + a: torch.Tensor | np.ndarray, size: tuple[int, ...] | None = None, replace: bool = True, p: torch.Tensor | None = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), + ) -> torch.Tensor: - raise NotImplementedError( - "the choice function is not implemented in torch" + a_numpy = a.cpu().numpy() + p_numpy = p.cpu().numpy() if p is not None else None + + return ( + torch.tensor( + np.random_choice(a_numpy, size=size, replace=replace, p=p_numpy), dtype=dtype, device=device + ) ) - + def multinomial( n: int, pvals: torch.Tensor, size: tuple[int, ...] | None = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - return torch.multinomial(pvals, n, size) + return torch.multinomial(pvals, n, size, dtype=dtype, device=device) def randint( low: int, high: int, size: tuple[int, ...] | None = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - return torch.randint(low, high, size) + return torch.randint(low, high, size, dtype=dtype, device=device) def shuffle(x: torch.Tensor) -> torch.Tensor: - return x[torch.randperm(x.shape[0])] + return x[torch.randperm(x.shape[0], device=x.device)] def uniform( low: float, high: float, size: tuple[int, ...] | None = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - return torch.rand(*size) * (high - low) + low + return torch.rand(*size, dtype=dtype, device=device) * (high - low) + low def normal( loc: float, scale: float, size: tuple[int, ...] | None = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - return torch.randn(*size) * scale + loc + return torch.randn(*size, dtype=dtype, device=device) * scale + loc def poisson( lam: float, size: tuple[int, ...] | None = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - return torch.poisson(torch.full(size, lam)) + return torch.poisson(torch.full(size, lam, dtype=dtype, device=device)) # TODO: implement the rest of the functions as they are needed From f0b3179faba319d969200d87e5110fae24fb11eb Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Mon, 30 Jun 2025 14:31:00 +0200 Subject: [PATCH 38/54] Update test_random.py --- deeptrack/tests/backend/test_random.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index 3469bcbc3..f8cd97976 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -3,7 +3,23 @@ import numpy as np import torch from deeptrack.backend.array_api_compat_ext.torch import random +""" +TODO: Implement tests for all of these functions to start with. + "rand", + "random", + "random_sample", + "randn", + "beta", + "binomial", + "choice", + "multinomial", + "randint", + "shuffle", + "uniform", + "normal", + "poisson", +""" class TestRandom(unittest.TestCase): def test_rand(self): From 6a646d56e1fd783e19cedf5c993b2e1ea39d48d5 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Tue, 1 Jul 2025 10:35:13 +0200 Subject: [PATCH 39/54] Implemented Gamma distribution --- .../array_api_compat_ext/torch/random.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index 69fc7983e..bb72012fd 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -19,7 +19,7 @@ """ from __future__ import annotations - +from deeptrack.types import ArrayLike import torch import numpy as np @@ -28,6 +28,7 @@ "random", "random_sample", "randn", + "standard_normal", "beta", "binomial", "choice", @@ -37,6 +38,7 @@ "uniform", "normal", "poisson", + "gamma", ] @@ -77,6 +79,12 @@ def randn( ) -> torch.Tensor: return torch.randn(*args, dtype=dtype, device=device) +def standard_normal( + *args: int, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), +) -> torch.Tensor: + return torch.randn(*args, dtype=dtype, device=device) def beta( a: float, @@ -115,7 +123,9 @@ def choice( return ( torch.tensor( - np.random_choice(a_numpy, size=size, replace=replace, p=p_numpy), dtype=dtype, device=device + np.random_choice( + a_numpy, size=size, replace=replace, p=p_numpy + ), dtype=dtype, device=device ) ) @@ -171,6 +181,23 @@ def poisson( device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: return torch.poisson(torch.full(size, lam, dtype=dtype, device=device)) + +def gamma( + shape: float | ArrayLike[float], + scale: float | ArrayLike[float] = 1.0, + size: tuple[int, ...] | None = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = torch.device("cpu"), +) -> torch.Tensor: + + shape = torch.as_tensor(shape, dtype=dtype, device=device) + scale = torch.as_tensor(scale, dtype=dtype, device=device) + + if size is not None: + shape = shape.expand(size) + scale = scale.expand(size) + gamma_distribution = torch.distributions.Gamma(shape, scale) + return gamma_distribution.sample() # TODO: implement the rest of the functions as they are needed From 5782d55f323ae3cb66de7cd47a84a08fe3353003 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:05:05 +0200 Subject: [PATCH 40/54] added exponential, geometric, multivar, dirichlet --- .../array_api_compat_ext/torch/random.py | 70 ++++++++++++++++--- 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index bb72012fd..78951944b 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -79,6 +79,7 @@ def randn( ) -> torch.Tensor: return torch.randn(*args, dtype=dtype, device=device) + def standard_normal( *args: int, dtype: torch.dtype = torch.float32, @@ -86,6 +87,7 @@ def standard_normal( ) -> torch.Tensor: return torch.randn(*args, dtype=dtype, device=device) + def beta( a: float, b: float, @@ -109,6 +111,7 @@ def binomial( torch.tensor(np.random.binomial(n, p, size), dtype=dtype, device=device) ) + def choice( a: torch.Tensor | np.ndarray, size: tuple[int, ...] | None = None, @@ -116,11 +119,10 @@ def choice( p: torch.Tensor | None = None, dtype: torch.dtype = torch.float32, device: torch.device | str = torch.device("cpu"), - ) -> torch.Tensor: + a_numpy = a.cpu().numpy() p_numpy = p.cpu().numpy() if p is not None else None - return ( torch.tensor( np.random_choice( @@ -181,10 +183,11 @@ def poisson( device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: return torch.poisson(torch.full(size, lam, dtype=dtype, device=device)) - + + def gamma( - shape: float | ArrayLike[float], - scale: float | ArrayLike[float] = 1.0, + shape: float | torch.Tensor, + scale: float | torch.Tensor = 1.0, size: tuple[int, ...] | None = None, dtype: torch.dtype = torch.float32, device: torch.device | str = torch.device("cpu"), @@ -192,12 +195,63 @@ def gamma( shape = torch.as_tensor(shape, dtype=dtype, device=device) scale = torch.as_tensor(scale, dtype=dtype, device=device) - if size is not None: shape = shape.expand(size) scale = scale.expand(size) + return torch.distributions.Gamma(shape, scale).sample() + + +def exponential( + scale: float | torch.Tensor = 1.0, + size: tuple[int, ...] = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = "cpu", +) -> torch.Tensor: + + rate = torch.as_tensor(1.0/scale, dtype=dtype, device=device) + if size is None: + return torch.distributions.Exponential(rate).sample() + return torch.distributions.Exponential(rate).sample(size) - gamma_distribution = torch.distributions.Gamma(shape, scale) - return gamma_distribution.sample() + +def multivariate_normal( + mean: torch.Tensor, + cov: torch.Tensor, + size: tuple[int, ...]] = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = "cpu", +) -> torch.Tensor: + + mean = mean.to(dtype=dtype, device=device) + cov = cov.to(dtype=dtype, device=device) + if size is None: + return torch.distributions.MultivariateNormal(mean, covariance_matrix=cov).sample() + return torch.distributions.MultivariateNormal(mean, covariance_matrix=cov).sample(size) + + +def geometric( + p: float | torch.Tensor, + size: tuple[int, ...] = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = "cpu", +) -> torch.Tensor: + + p = torch.as_tensor(p, dtype=torch.float32, device=device) + if size is None: + return torch.distributions.Geometric(probs=p).sample().to(dtype) + return torch.distributions.Geometric(probs=p).sample(size).to(dtype) + + +def dirichlet( + alpha: torch.Tensor, + size: tuple[int, ...] = None, + dtype: torch.dtype = torch.float32, + device: torch.device | str = "cpu", +) -> torch.Tensor: + + alpha = alpha.to(dtype=dtype, device=device) + if size is None: + return torch.distributions.Dirichlet(alpha).sample() + return torch.distributions.Dirichlet(alpha).sample(size) # TODO: implement the rest of the functions as they are needed From 1b7307f83758cad31df348f0411e5c8dcfdf8e12 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:11:36 +0200 Subject: [PATCH 41/54] syntax --- deeptrack/backend/array_api_compat_ext/torch/random.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index 78951944b..cddc81784 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -19,7 +19,6 @@ """ from __future__ import annotations -from deeptrack.types import ArrayLike import torch import numpy as np @@ -217,7 +216,7 @@ def exponential( def multivariate_normal( mean: torch.Tensor, cov: torch.Tensor, - size: tuple[int, ...]] = None, + size: tuple[int, ...] = None, dtype: torch.dtype = torch.float32, device: torch.device | str = "cpu", ) -> torch.Tensor: From cefbc936723c875d8d94cf00ae789511ea25d9e8 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:17:48 +0200 Subject: [PATCH 42/54] docs --- deeptrack/backend/array_api_compat_ext/torch/random.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index cddc81784..d212e037d 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -1,7 +1,7 @@ -"""xp compatibility module for Numpy functions +"""Compatibility module for Numpy functions -This module contains wrapper functions for various numpy.random functions -that return torch tensors when and can accept optional `dtype` and`device` arguments. +This module contains helper functions for various numpy.random functions +that return torch.Tensors when used. Accept optional `dtype` and`device` arguments. Examples From 063b44adfc6a0134a302f019ca8f8d9066cc6108 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:20:11 +0200 Subject: [PATCH 43/54] formatting, linebreaks --- .../array_api_compat_ext/torch/random.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index d212e037d..3e1bb94ac 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -1,7 +1,7 @@ """Compatibility module for Numpy functions This module contains helper functions for various numpy.random functions -that return torch.Tensors when used. Accept optional `dtype` and`device` arguments. +that return torch.Tensors when used. Accept optional `dtype` and `device` arguments. Examples @@ -95,7 +95,9 @@ def beta( device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: return ( - torch.tensor(np.random.beta(a, b, size), dtype=dtype, device=device) + torch.tensor( + np.random.beta(a, b, size), dtype=dtype, device=device + ) ) @@ -107,7 +109,9 @@ def binomial( device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: return ( - torch.tensor(np.random.binomial(n, p, size), dtype=dtype, device=device) + torch.tensor( + np.random.binomial(n, p, size), dtype=dtype, device=device + ) ) @@ -119,7 +123,6 @@ def choice( dtype: torch.dtype = torch.float32, device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - a_numpy = a.cpu().numpy() p_numpy = p.cpu().numpy() if p is not None else None return ( @@ -149,7 +152,7 @@ def randint( device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: return torch.randint(low, high, size, dtype=dtype, device=device) - + def shuffle(x: torch.Tensor) -> torch.Tensor: return x[torch.randperm(x.shape[0], device=x.device)] @@ -191,7 +194,6 @@ def gamma( dtype: torch.dtype = torch.float32, device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: - shape = torch.as_tensor(shape, dtype=dtype, device=device) scale = torch.as_tensor(scale, dtype=dtype, device=device) if size is not None: @@ -206,7 +208,6 @@ def exponential( dtype: torch.dtype = torch.float32, device: torch.device | str = "cpu", ) -> torch.Tensor: - rate = torch.as_tensor(1.0/scale, dtype=dtype, device=device) if size is None: return torch.distributions.Exponential(rate).sample() @@ -220,12 +221,13 @@ def multivariate_normal( dtype: torch.dtype = torch.float32, device: torch.device | str = "cpu", ) -> torch.Tensor: - mean = mean.to(dtype=dtype, device=device) cov = cov.to(dtype=dtype, device=device) if size is None: - return torch.distributions.MultivariateNormal(mean, covariance_matrix=cov).sample() - return torch.distributions.MultivariateNormal(mean, covariance_matrix=cov).sample(size) + return torch.distributions.MultivariateNormal( + mean, covariance_matrix=cov).sample() + return torch.distributions.MultivariateNormal( + mean, covariance_matrix=cov).sample(size) def geometric( @@ -234,11 +236,12 @@ def geometric( dtype: torch.dtype = torch.float32, device: torch.device | str = "cpu", ) -> torch.Tensor: - p = torch.as_tensor(p, dtype=torch.float32, device=device) if size is None: - return torch.distributions.Geometric(probs=p).sample().to(dtype) - return torch.distributions.Geometric(probs=p).sample(size).to(dtype) + return torch.distributions.Geometric( + probs=p).sample().to(dtype) + return torch.distributions.Geometric( + probs=p).sample(size).to(dtype) def dirichlet( @@ -247,7 +250,6 @@ def dirichlet( dtype: torch.dtype = torch.float32, device: torch.device | str = "cpu", ) -> torch.Tensor: - alpha = alpha.to(dtype=dtype, device=device) if size is None: return torch.distributions.Dirichlet(alpha).sample() From 9aa82b8bdb88a667fadbaf97a3f0eff5dae7f0dc Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:27:26 +0200 Subject: [PATCH 44/54] docs --- deeptrack/backend/array_api_compat_ext/torch/random.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index 3e1bb94ac..308bc174d 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -1,7 +1,9 @@ """Compatibility module for Numpy functions -This module contains helper functions for various numpy.random functions -that return torch.Tensors when used. Accept optional `dtype` and `device` arguments. +This module contains helper functions that use the same syntax as +the equivalent numpy.random functions. All functions return +torch.Tensors when used and accept optional `dtype` and `device` +arguments that default to `float32` and `cpu`. Examples From 0a6d561ea8ab553a45221167b6400b5bc0b80881 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 3 Jul 2025 11:52:58 +0200 Subject: [PATCH 45/54] removed torch from requirements --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 228b04918..9d7eb35e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ numpy matplotlib scipy scikit-image -torch more_itertools pint pandas From 0e8aa7c34c91e9c4db0b8dbfcb20b9883b9cbe4e Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 3 Jul 2025 13:15:50 +0200 Subject: [PATCH 46/54] check torch availability in tests --- deeptrack/tests/backend/test_random.py | 36 ++++++++++++++------------ 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index f8cd97976..4affe6d00 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -1,8 +1,11 @@ import unittest import numpy as np -import torch + +from deeptrack.backend import TORCH_AVAILABLE from deeptrack.backend.array_api_compat_ext.torch import random + + """ TODO: Implement tests for all of these functions to start with. "rand", @@ -22,18 +25,19 @@ """ class TestRandom(unittest.TestCase): - def test_rand(self): - shapes = [(2, ), (3, 4)] - dtypes = [torch.float32, torch.float64] - devices = [torch.device("cpu"), "cpu"] - - for shape, dtype, device in zip(shapes, dtypes, devices): - - expected = np.random.rand(*shape) - generated = random.rand(*shape, dtype=dtype, device=device) - self.assertEqual(generated.shape, expected.shape) - self.assertEqual(generated.dtype, dtype) - - a = random.rand(100, dtype=torch.float32, device="cpu") - b = np.random.rand(100) - self.assertAlmostEqual(a.mean(), np.mean(b), delta=1) # Use a different rand + if TORCH_AVAILABLE: + def test_rand(self): + shapes = [(2, ), (3, 4)] + dtypes = [torch.float32, torch.float64] + devices = [torch.device("cpu"), "cpu"] + + for shape, dtype, device in zip(shapes, dtypes, devices): + + expected = np.random.rand(*shape) + generated = random.rand(*shape, dtype=dtype, device=device) + self.assertEqual(generated.shape, expected.shape) + self.assertEqual(generated.dtype, dtype) + + a = random.rand(100, dtype=torch.float32, device="cpu") + b = np.random.rand(100) + self.assertAlmostEqual(a.mean(), np.mean(b), delta=1) # Use a different rand From 296ca89f06d62dc04aa19242719ac5d1b3075c8d Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 3 Jul 2025 13:38:15 +0200 Subject: [PATCH 47/54] Update test_random.py --- deeptrack/tests/backend/test_random.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index 4affe6d00..4f461956a 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -23,9 +23,10 @@ "poisson", """ - -class TestRandom(unittest.TestCase): - if TORCH_AVAILABLE: +if TORCH_AVAILABLE: + import torch + class TestRandom(unittest.TestCase): + def test_rand(self): shapes = [(2, ), (3, 4)] dtypes = [torch.float32, torch.float64] @@ -41,3 +42,6 @@ def test_rand(self): a = random.rand(100, dtype=torch.float32, device="cpu") b = np.random.rand(100) self.assertAlmostEqual(a.mean(), np.mean(b), delta=1) # Use a different rand + + if __name__ == "__main__": + unittest.main() From fa4f6cdf3f01d89ca6660453d6fa9eaeb58502f4 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 3 Jul 2025 13:52:57 +0200 Subject: [PATCH 48/54] Update __init__.py --- deeptrack/backend/array_api_compat_ext/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deeptrack/backend/array_api_compat_ext/__init__.py b/deeptrack/backend/array_api_compat_ext/__init__.py index e67f772af..87c90da85 100644 --- a/deeptrack/backend/array_api_compat_ext/__init__.py +++ b/deeptrack/backend/array_api_compat_ext/__init__.py @@ -1,4 +1,6 @@ -from array_api_compat import torch as apctorch +from deeptrack import TORCH_AVAILABLE +if TORCH_AVAILABLE: + from array_api_compat import torch as apctorch from deeptrack.backend.array_api_compat_ext.torch import random # NumPy and PyTorch random functions are incompatible with each other. From fa30c97e6efcaaae3d579fa923ea4101836280f9 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 3 Jul 2025 13:54:39 +0200 Subject: [PATCH 49/54] Update __init__.py --- deeptrack/backend/array_api_compat_ext/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptrack/backend/array_api_compat_ext/__init__.py b/deeptrack/backend/array_api_compat_ext/__init__.py index 87c90da85..d6afce249 100644 --- a/deeptrack/backend/array_api_compat_ext/__init__.py +++ b/deeptrack/backend/array_api_compat_ext/__init__.py @@ -1,7 +1,7 @@ from deeptrack import TORCH_AVAILABLE if TORCH_AVAILABLE: from array_api_compat import torch as apctorch -from deeptrack.backend.array_api_compat_ext.torch import random + from deeptrack.backend.array_api_compat_ext.torch import random # NumPy and PyTorch random functions are incompatible with each other. # The current array_api_compat module does not fix this incompatibility. From d2a313f66fc66d35cd983b41e460a9fbd7b21b86 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 3 Jul 2025 13:54:50 +0200 Subject: [PATCH 50/54] Update __init__.py --- deeptrack/backend/array_api_compat_ext/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeptrack/backend/array_api_compat_ext/__init__.py b/deeptrack/backend/array_api_compat_ext/__init__.py index d6afce249..12e2309f1 100644 --- a/deeptrack/backend/array_api_compat_ext/__init__.py +++ b/deeptrack/backend/array_api_compat_ext/__init__.py @@ -7,4 +7,4 @@ # The current array_api_compat module does not fix this incompatibility. # So we implement our own patch, which implements a numpy-compatible interface # for the torch random functions. -apctorch.random = random + apctorch.random = random From 328a69b3e21b69379998809de11d455639cbe62c Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 3 Jul 2025 15:03:29 +0200 Subject: [PATCH 51/54] Update random.py --- deeptrack/backend/array_api_compat_ext/torch/random.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index 308bc174d..5a75825d3 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -21,9 +21,13 @@ """ from __future__ import annotations -import torch +from deeptrack import TORCH_AVAILABLE + import numpy as np +if TORCH_AVAILABLE: + import torch + __all__ = [ "rand", "random", From a4fb70abb0c75f9d5258b99a624003a4217b5f94 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Thu, 3 Jul 2025 15:04:24 +0200 Subject: [PATCH 52/54] Update __init__.py --- deeptrack/backend/array_api_compat_ext/torch/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/__init__.py b/deeptrack/backend/array_api_compat_ext/torch/__init__.py index bc9dc47f2..2d2c1dc3d 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/__init__.py +++ b/deeptrack/backend/array_api_compat_ext/torch/__init__.py @@ -1,4 +1,7 @@ -from deeptrack.backend.array_api_compat_ext.torch import random +from deeptrack import TORCH_AVAILABLE + +if TORCH_AVAILABLE: + from deeptrack.backend.array_api_compat_ext.torch import random __all__ = ["random"] From 99771c19cf07766cefcc427cb7e808882a225fa2 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Mon, 25 Aug 2025 14:48:58 +0200 Subject: [PATCH 53/54] Update random.py --- deeptrack/backend/array_api_compat_ext/torch/random.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deeptrack/backend/array_api_compat_ext/torch/random.py b/deeptrack/backend/array_api_compat_ext/torch/random.py index 5a75825d3..c9ae74481 100644 --- a/deeptrack/backend/array_api_compat_ext/torch/random.py +++ b/deeptrack/backend/array_api_compat_ext/torch/random.py @@ -138,7 +138,7 @@ def choice( ), dtype=dtype, device=device ) ) - + def multinomial( n: int, @@ -158,7 +158,7 @@ def randint( device: torch.device | str = torch.device("cpu"), ) -> torch.Tensor: return torch.randint(low, high, size, dtype=dtype, device=device) - + def shuffle(x: torch.Tensor) -> torch.Tensor: return x[torch.randperm(x.shape[0], device=x.device)] From cad33b87d30f3c6e9985e46d0b5cb622044797b0 Mon Sep 17 00:00:00 2001 From: Alex <95913221+Pwhsky@users.noreply.github.com> Date: Mon, 25 Aug 2025 15:01:46 +0200 Subject: [PATCH 54/54] decorator added for test --- deeptrack/tests/backend/test_random.py | 50 +++++++++++++------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/deeptrack/tests/backend/test_random.py b/deeptrack/tests/backend/test_random.py index 4f461956a..e145a3ddb 100644 --- a/deeptrack/tests/backend/test_random.py +++ b/deeptrack/tests/backend/test_random.py @@ -2,8 +2,11 @@ import numpy as np -from deeptrack.backend import TORCH_AVAILABLE -from deeptrack.backend.array_api_compat_ext.torch import random +from deeptrack.backend import TORCH_AVAILABLE + +if TORCH_AVAILABLE: + import torch + from deeptrack.backend.array_api_compat_ext.torch import random """ @@ -23,25 +26,24 @@ "poisson", """ -if TORCH_AVAILABLE: - import torch - class TestRandom(unittest.TestCase): - - def test_rand(self): - shapes = [(2, ), (3, 4)] - dtypes = [torch.float32, torch.float64] - devices = [torch.device("cpu"), "cpu"] - - for shape, dtype, device in zip(shapes, dtypes, devices): - - expected = np.random.rand(*shape) - generated = random.rand(*shape, dtype=dtype, device=device) - self.assertEqual(generated.shape, expected.shape) - self.assertEqual(generated.dtype, dtype) - - a = random.rand(100, dtype=torch.float32, device="cpu") - b = np.random.rand(100) - self.assertAlmostEqual(a.mean(), np.mean(b), delta=1) # Use a different rand - - if __name__ == "__main__": - unittest.main() +@unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.") +class TestRandom(unittest.TestCase): + + def test_rand(self): + shapes = [(2, ), (3, 4)] + dtypes = [torch.float32, torch.float64] + devices = [torch.device("cpu"), "cpu"] + + for shape, dtype, device in zip(shapes, dtypes, devices): + + expected = np.random.rand(*shape) + generated = random.rand(*shape, dtype=dtype, device=device) + self.assertEqual(generated.shape, expected.shape) + self.assertEqual(generated.dtype, dtype) + + a = random.rand(100, dtype=torch.float32, device="cpu") + b = np.random.rand(100) + self.assertAlmostEqual(a.mean(), np.mean(b), delta=1) # Use a different rand + +if __name__ == "__main__": + unittest.main()