Skip to content

Commit 5d243a4

Browse files
authored
updates pyrshow error message, adds tests (#40)
1 parent 852f45e commit 5d243a4

File tree

2 files changed

+94
-15
lines changed

2 files changed

+94
-15
lines changed

TESTS/unitTests.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,32 @@ def test_animshow_fail_n_frames(self):
14601460
with self.assertRaises(Exception):
14611461
fig = pt.animshow([vid1, vid2], as_html5=False)._fig
14621462

1463+
1464+
class TestPyrshow(unittest.TestCase):
1465+
1466+
def test_pyrshow_1d(self):
1467+
signal = np.random.rand(256)
1468+
pyr = pt.pyramids.GaussianPyramid(signal)
1469+
pt.pyrshow(pyr.pyr_coeffs)
1470+
1471+
def test_pyrshow_1d_weird_shape(self):
1472+
# unlike 2d pyrshow, 1d pyrshow works with any shapes
1473+
signal = np.random.rand(255)
1474+
pyr = pt.pyramids.GaussianPyramid(signal)
1475+
pt.pyrshow(pyr.pyr_coeffs)
1476+
1477+
def test_pyrshow_2d(self):
1478+
img = np.random.rand(256, 256)
1479+
pyr = pt.pyramids.GaussianPyramid(img)
1480+
pt.pyrshow(pyr.pyr_coeffs)
1481+
1482+
def test_pyrshow_2d_shape_err(self):
1483+
img = np.random.rand(255, 255)
1484+
pyr = pt.pyramids.GaussianPyramid(img)
1485+
with self.assertRaises(ValueError):
1486+
pt.pyrshow(pyr.pyr_coeffs)
1487+
1488+
14631489
def main():
14641490
unittest.main()
14651491

src/pyrtools/tools/display.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,8 @@ def colormap_range(image, vrange='indep1', cmap=None):
324324
return vrange_list, cmap
325325

326326

327-
def find_zooms(images, video=False):
328-
"""find the zooms necessary to display a list of images
327+
def _check_shapes(images, video=False):
328+
"""Helper function to check whether images can be zoomed in appropriately.
329329
330330
this convenience function takes a list of images and finds out if they can all be displayed at
331331
the same size. for this to be the case, there must be an integer for each image such that the
@@ -341,28 +341,68 @@ def find_zooms(images, video=False):
341341
342342
Returns
343343
-------
344-
zooms : `list`
345-
list of integers showing how much each image needs to be zoomed
346344
max_shape : `tuple`
347345
2-tuple of integers, showing the shape of the largest image in the list
348346
347+
Raises
348+
------
349+
ValueError :
350+
If the images cannot be zoomed to the same. that is, if there is not an integer
351+
for each image such that the image can be multiplied by that integer to be the
352+
same size as the biggest image.
349353
"""
350354
def check_shape_1d(shapes):
351355
max_shape = np.max(shapes)
352356
for s in shapes:
353357
if not (max_shape % s) == 0:
354-
raise Exception("All images must be able to be 'zoomed in' to the largest image."
355-
"That is, the largest image must be a scalar multiple of all "
356-
"images.")
358+
raise ValueError("All images must be able to be 'zoomed in' to the largest image."
359+
"That is, the largest image must be a scalar multiple of all "
360+
"images.")
357361
return max_shape
358-
359362
if video:
360363
time_dim = 1
361364
else:
362365
time_dim = 0
363366
max_shape = []
364367
for i in range(2):
365368
max_shape.append(check_shape_1d([img.shape[i+time_dim] for img in images]))
369+
return max_shape
370+
371+
372+
def find_zooms(images, video=False):
373+
"""find the zooms necessary to display a list of images
374+
375+
Arguments
376+
---------
377+
images : `list`
378+
list of numpy arrays to check the size of. In practice, these are 1d or 2d, but can in
379+
principle be any number of dimensions
380+
video: bool, optional (default False)
381+
handling signals in both space and time or only space.
382+
383+
Returns
384+
-------
385+
zooms : `list`
386+
list of integers showing how much each image needs to be zoomed
387+
max_shape : `tuple`
388+
2-tuple of integers, showing the shape of the largest image in the list
389+
390+
Raises
391+
------
392+
ValueError :
393+
If the images cannot be zoomed to the same. that is, if there is not an integer
394+
for each image such that the image can be multiplied by that integer to be the
395+
same size as the biggest image.
396+
ValueError :
397+
If the two image dimensions require different levels of zoom (e.g., if the
398+
height must be zoomed by 2 but the width must be zoomed by 3).
399+
400+
"""
401+
max_shape = _check_shapes(images, video)
402+
if video:
403+
time_dim = 1
404+
else:
405+
time_dim = 0
366406
zooms = []
367407
for img in images:
368408
# this checks that there's only one unique value in the list
@@ -373,8 +413,8 @@ def check_shape_1d(shapes):
373413
# the first two non-time dimensions (so we'll ignore the RGBA channel
374414
# if any image has that)
375415
if len(set([s // img.shape[i+time_dim] for i, s in enumerate(max_shape)])) > 1:
376-
raise Exception("Both height and width must be multiplied by same amount but got "
377-
"image shape {} and max_shape {}!".format(img.shape, max_shape))
416+
raise ValueError("Both height and width must be multiplied by same amount but got "
417+
"image shape {} and max_shape {}!".format(img.shape, max_shape))
378418
zooms.append(max_shape[0] // img.shape[0])
379419
return zooms, max_shape
380420

@@ -839,9 +879,6 @@ def animate_video(t):
839879
def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1, show_residuals=True, **kwargs):
840880
"""Display the coefficients of the pyramid in an orderly fashion
841881
842-
NOTE: this currently only works for 2d signals. we still need to figure out how to handle 1D
843-
signals.
844-
845882
Arguments
846883
---------
847884
pyr_coeffs : `dict`
@@ -894,8 +931,6 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1
894931
# pasting all coefficients into a giant array.
895932
# and the steerable pyramids have a num_orientations attribute
896933

897-
# TODO make list of different elements in each dim
898-
# then only loop through those - see below line 655
899934
num_scales = np.max(np.array([k for k in pyr_coeffs.keys() if isinstance(k, tuple)])[:,0]) + 1
900935
num_orientations = np.max(np.array([k for k in pyr_coeffs.keys() if isinstance(k, tuple)])[:,1]) + 1
901936

@@ -939,4 +974,22 @@ def pyrshow(pyr_coeffs, is_complex=False, vrange='indep1', col_wrap=None, zoom=1
939974
ax.set_title(titles[i])
940975
return fig
941976
else:
977+
try:
978+
_check_shapes(imgs)
979+
except ValueError:
980+
err_scales = num_scales
981+
residual_err_msg = ""
982+
shapes = [(imgs[0].shape[0]/ 2**i, imgs[0].shape[1] / 2**i) for i in range(err_scales)]
983+
err_msg = [f"scale {i} shape: {sh}" for i, sh in enumerate(shapes)]
984+
if show_residuals:
985+
err_scales += 1
986+
residual_err_msg = ", plus 1 (for the residual lowpass)"
987+
shape = (imgs[0].shape[0]/ int(2**err_scales), imgs[0].shape[1] / int(2**err_scales))
988+
err_msg.append(f"residual lowpass shape: {shape}")
989+
err_msg = "\n".join(err_msg)
990+
raise ValueError("In order to correctly display pyramid coefficients, the shape of"
991+
f" the initial image must be evenly divisible by two {err_scales} "
992+
"times, where this number is the height of the "
993+
f"pyramid{residual_err_msg}. "
994+
f"Instead, found:\n{err_msg}")
942995
return imshow(imgs, vrange=vrange, col_wrap=col_wrap, zoom=zoom, title=titles, **kwargs)

0 commit comments

Comments
 (0)