Skip to content

Image/pad_image_to_fft #408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions deeptrack/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,17 @@ class is central to DeepTrack2, acting as a container for numerical data
import operator as ops
from typing import Any, Callable, Iterable

import array_api_compat as apc
import numpy as np
from numpy.typing import NDArray

from deeptrack.backend import config, TORCH_AVAILABLE, xp
from deeptrack.properties import Property
from deeptrack.types import NumberLike

if TORCH_AVAILABLE:
import torch


#TODO ***??*** revise _binary_method - typing, docstring, unit test
def _binary_method(
Expand Down Expand Up @@ -1694,28 +1700,27 @@ def coerce(
_FASTEST_SIZES = np.sort(_FASTEST_SIZES)


#TODO ***??*** revise pad_image_to_fft - typing, docstring, unit test
def pad_image_to_fft(
image: Image | np.ndarray | np.ndarray,
image: Image | NDArray | torch.Tensor,
axes: Iterable[int] = (0, 1),
) -> Image | np.ndarray:
"""Pads an image to optimize Fast Fourier Transform (FFT) performance.
) -> Image | NDArray | torch.Tensor:
"""Pad an image to optimize Fast Fourier Transform (FFT) performance.

This function pads an image by adding zeros to the end of specified axes
so that their lengths match the nearest larger size in `_FASTEST_SIZES`.
These sizes are selected to optimize FFT computations.

Parameters
----------
image: Image | np.ndarray
image: Image | np.ndarray | torch.Tensor
The input image to pad. It should be an instance of the `Image` class
or any array-like structure compatible with FFT operations.
axes: Iterable[int], optional
The axes along which to apply padding. Defaults to `(0, 1)`.

Returns
-------
Image | np.ndarray
Image | np.ndarray | torch.Tensor
The padded image with dimensions optimized for FFT performance.

Raises
Expand All @@ -1729,26 +1734,30 @@ def pad_image_to_fft(
>>> from deeptrack.image import Image, pad_image_to_fft

Pad an Image object:

>>> img = Image(np.zeros((7, 13)))
>>> img = Image(np.ones((7, 13)))
>>> padded_img = pad_image_to_fft(img)
>>> print(padded_img.shape)
(8, 16)

Pad a NumPy array:

>>> img = np.zeros((5, 11)))
>>> img = np.ones((5, 11))
>>> padded_img = pad_image_to_fft(img)
>>> print(padded_img.shape)
(6, 12)

Pad a PyTorch tensor:

>>> import torch
>>> img = torch.ones(7, 11)
>>> padded_img = pad_image_to_fft(img)
>>> print(padded_img.shape)
torch.Size([8, 12])

"""

def _closest(
dim: int,
) -> int:

# Returns the smallest value frin _FASTEST_SIZES larger than dim.
# Returns the smallest value from _FASTEST_SIZES that is >= dim.
for size in _FASTEST_SIZES:
if size >= dim:
return size
Expand All @@ -1763,7 +1772,18 @@ def _closest(
new_shape[axis] = _closest(new_shape[axis])

# Calculate the padding for each axis.
pad_width = [(0, increase) for increase in np.array(new_shape) - image.shape]
pad_width = [
(0, increase)
for increase in np.array(new_shape) - np.array(image.shape)
]

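# Apply zero-padding with torch.nn.functional.pad if the input is a
# PyTorch tensor. F.pad expects the (before, after) pairs ordered from the
# last dimension to the first, hence the reversed iteration below.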
if apc.is_torch_array(image):
pad = []
for before, after in reversed(pad_width):
pad.extend([before, after])
return torch.nn.functional.pad(image, pad, mode="constant", value=0)

# Pad the image using constant mode (add zeros).
# Apply zero-padding with np.pad if the input is a NumPy array or an Image
return np.pad(image, pad_width, mode="constant")
33 changes: 32 additions & 1 deletion deeptrack/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

import numpy as np

from deeptrack import features, image
from deeptrack import features, image, TORCH_AVAILABLE

if TORCH_AVAILABLE:
import torch


class TestImage(unittest.TestCase):
Expand Down Expand Up @@ -389,6 +392,7 @@ def test_Image__view(self):

def test_pad_image_to_fft(self):

# Test with dt.Image
input_image = image.Image(np.zeros((7, 25)))
padded_image = image.pad_image_to_fft(input_image)
self.assertEqual(padded_image.shape, (8, 27))
Expand All @@ -401,6 +405,33 @@ def test_pad_image_to_fft(self):
padded_image = image.pad_image_to_fft(input_image)
self.assertEqual(padded_image.shape, (324, 432))

# Test with NumPy array
input_image = np.ones((7, 13))
padded_image = image.pad_image_to_fft(input_image)
self.assertEqual(padded_image.shape, (8, 16))

input_image = np.ones((5,))
padded_image = image.pad_image_to_fft(input_image, axes=(0,))
self.assertEqual(padded_image.shape, (6,))

# Test with PyTorch tensor (if available)
if TORCH_AVAILABLE:
input_image = torch.ones(3, 5)
padded_image = image.pad_image_to_fft(input_image)
self.assertEqual(padded_image.shape, (3, 6))
self.assertIsInstance(padded_image, torch.Tensor)

input_image = torch.ones(5, 7, 11, 13)
padded_image = image.pad_image_to_fft(input_image, axes=(0, 1, 3))
padded_image_np = image.pad_image_to_fft(
input_image.numpy(), axes=(0, 1, 3)
)
self.assertEqual(padded_image.shape, (6, 8, 11, 16))
self.assertIsInstance(padded_image, torch.Tensor)
np.testing.assert_allclose(
padded_image.numpy(), padded_image_np, atol=1e-6
)


if __name__ == "__main__":
unittest.main()
Loading