@@ -280,7 +280,7 @@ def test_evaluation_methods(self, K=30, F=5, N=100):
280
280
281
281
def test_filter_identity (self , M = 10 , c = 2.3 ):
282
282
r"""Test that filtering with c0 only scales the signal."""
283
- x = self ._rs .uniform (size = (M , 1 , self ._G .N ))
283
+ x = self ._rs .uniform (size = (M , self ._G .N ))
284
284
f = filters .Chebyshev (self ._G , c )
285
285
y = f .filter (x , method = 'recursive' )
286
286
np .testing .assert_equal (y , c * x )
@@ -331,3 +331,56 @@ def test_approximations(self, N=100, K=20):
331
331
y1 = f1 .filter (x .T ).T
332
332
y2 = f2 .filter (x )
333
333
np .testing .assert_allclose (y2 .squeeze (), y1 )
334
+
335
+ def test_shape_normalization (self ):
336
+ """Test that signal's shapes are properly normalized."""
337
+ # TODO: should also test filters which are not approximations.
338
+
339
+ def test_normalization (M , Fin , Fout , K = 7 ):
340
+
341
+ def test_shape (y , M , Fout , N = self ._G .N ):
342
+ """Test that filtered signals are squeezed."""
343
+ if Fout == 1 and M == 1 :
344
+ self .assertEqual (y .shape , (N ,))
345
+ elif Fout == 1 :
346
+ self .assertEqual (y .shape , (M , N ))
347
+ elif M == 1 :
348
+ self .assertEqual (y .shape , (Fout , N ))
349
+ else :
350
+ self .assertEqual (y .shape , (M , Fout , N ))
351
+
352
+ coefficients = self ._rs .uniform (size = (K , Fout , Fin ))
353
+ f = filters .Chebyshev (self ._G , coefficients )
354
+ assert f .shape == (Fin , Fout )
355
+ assert (f .n_features_in , f .n_features_out ) == (Fin , Fout )
356
+
357
+ x = self ._rs .uniform (size = (M , Fin , self ._G .N ))
358
+ y = f .filter (x )
359
+ test_shape (y , M , Fout )
360
+
361
+ if Fin == 1 or M == 1 :
362
+ # It only makes sense to squeeze if one dimension is unitary.
363
+ x = x .squeeze ()
364
+ y = f .filter (x )
365
+ test_shape (y , M , Fout )
366
+
367
+ # Test all possible correct combinations of input and output signals.
368
+ for M in [1 , 9 ]:
369
+ for Fin in [1 , 3 ]:
370
+ for Fout in [1 , 5 ]:
371
+ test_normalization (M , Fin , Fout )
372
+
373
+ # Test failure cases.
374
+ M , Fin , Fout , K = 9 , 3 , 5 , 7
375
+ coefficients = self ._rs .uniform (size = (K , Fout , Fin ))
376
+ f = filters .Chebyshev (self ._G , coefficients )
377
+ x = self ._rs .uniform (size = (M , Fin , 2 ))
378
+ self .assertRaises (ValueError , f .filter , x )
379
+ x = self ._rs .uniform (size = (M , 2 , self ._G .N ))
380
+ self .assertRaises (ValueError , f .filter , x )
381
+ x = self ._rs .uniform (size = (2 , self ._G .N ))
382
+ self .assertRaises (ValueError , f .filter , x )
383
+ x = self ._rs .uniform (size = (self ._G .N ))
384
+ self .assertRaises (ValueError , f .filter , x )
385
+ x = self ._rs .uniform (size = (2 , M , Fin , self ._G .N ))
386
+ self .assertRaises (ValueError , f .filter , x )
0 commit comments