Skip to content

Commit 98a19cc

Browse files
authored
updates steerpyr behavior (#39)
* updates steerpyr behavior to match plenoptic * replace Exception with ValueError * adds height check * raise ValueError if is_complex and order==0
1 parent 5d243a4 commit 98a19cc

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

src/pyrtools/pyramids/SteerablePyramidFreq.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ class SteerablePyramidFreq(SteerablePyramidBase):
2020
The squared radial functions tile the Fourier plane with a raised-cosine
2121
falloff. Angular functions are cos(theta- k*pi/order+1)^(order).
2222
23+
Note that reconstruction will not be exact if the image has an odd shape (due to
24+
boundary-handling issues) or if the pyramid is complex with order=0.
25+
2326
Notes
2427
-----
2528
Transform described in [1]_, filter kernel design described in [2]_.
@@ -30,7 +33,7 @@ class SteerablePyramidFreq(SteerablePyramidBase):
3033
2d image upon which to construct to the pyramid.
3134
height : 'auto' or `int`.
3235
The height of the pyramid. If 'auto', will automatically determine based on the size of
33-
`image`.
36+
`image`. If an int, must be non-negative. When height=0, only returns the residuals.
3437
order : `int`.
3538
The Gaussian derivative order used for the steerable filters. Default value is 3.
3639
Note that to achieve steerability the minimum number of orientation is `order` + 1,
@@ -52,7 +55,8 @@ class SteerablePyramidFreq(SteerablePyramidBase):
5255
Human-readable string specifying the type of pyramid. For base class, is None.
5356
pyr_coeffs : `dict`
5457
Dictionary containing the coefficients of the pyramid. Keys are `(level, band)` tuples and
55-
values are 1d or 2d numpy arrays (same number of dimensions as the input image)
58+
values are 1d or 2d numpy arrays (same number of dimensions as the input image),
59+
running from fine to coarse.
5660
pyr_size : `dict`
5761
Dictionary containing the sizes of the pyramid coefficients. Keys are `(level, band)`
5862
tuples and values are tuples.
@@ -66,6 +70,7 @@ class SteerablePyramidFreq(SteerablePyramidBase):
6670
Oct 1995.
6771
.. [2] A Karasaridis and E P Simoncelli, "A Filter Design Technique for Steerable Pyramid
6872
Image Transforms", ICASSP, Atlanta, GA, May 1996.
73+
6974
"""
7075
def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False):
7176
# in the Fourier domain, there's only one choice for how do edge-handling: circular. to
@@ -78,24 +83,35 @@ def __init__(self, image, height='auto', order=3, twidth=1, is_complex=False):
7883
self.filters = {}
7984
self.order = int(order)
8085

86+
if (image.shape[0] % 2 != 0) or (image.shape[1] % 2 != 0):
87+
warnings.warn("Reconstruction will not be perfect with odd-sized images")
88+
89+
if self.order == 0 and self.is_complex:
90+
raise ValueError(
91+
"Complex pyramid cannot have order=0! See "
92+
"https://github.com/plenoptic-org/plenoptic/issues/326 "
93+
"for an explanation."
94+
)
95+
8196
# we can't use the base class's _set_num_scales method because the max height is calculated
8297
# slightly differently
8398
max_ht = np.floor(np.log2(min(self.image.shape))) - 2
8499
if height == 'auto' or height is None:
85100
self.num_scales = int(max_ht)
86101
elif height > max_ht:
87-
raise Exception("Cannot build pyramid higher than %d levels." % (max_ht))
102+
raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht))
103+
elif height < 0:
104+
raise ValueError("Height must be a non-negative int.")
88105
else:
89106
self.num_scales = int(height)
90107

91108
if self.order > 15 or self.order < 0:
92-
raise Exception("order must be an integer in the range [0,15]. Truncating.")
109+
raise ValueError("order must be an integer in the range [0,15].")
93110

94111
self.num_orientations = int(order + 1)
95112

96113
if twidth <= 0:
97-
warnings.warn("twidth must be positive. Setting to 1.")
98-
twidth = 1
114+
raise ValueError("twidth must be positive.")
99115
twidth = int(twidth)
100116

101117
dims = np.array(self.image.shape)
@@ -220,8 +236,7 @@ def recon_pyr(self, levels='all', bands='all', twidth=1):
220236
221237
"""
222238
if twidth <= 0:
223-
warnings.warn("twidth must be positive. Setting to 1.")
224-
twidth = 1
239+
raise ValueError("twidth must be positive.")
225240

226241
recon_keys = self._recon_keys(levels, bands)
227242

src/pyrtools/pyramids/pyramid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _set_num_scales(self, filter_name, height, extra_height=0):
9696
if height == 'auto':
9797
self.num_scales = max_ht
9898
elif height > max_ht:
99-
raise Exception("Cannot build pyramid higher than %d levels." % (max_ht))
99+
raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht))
100100
else:
101101
self.num_scales = int(height)
102102

0 commit comments

Comments
 (0)