@@ -280,7 +280,7 @@ def test_evaluation_methods(self, K=30, F=5, N=100):
280280
281281 def test_filter_identity (self , M = 10 , c = 2.3 ):
282282 r"""Test that filtering with c0 only scales the signal."""
283- x = self ._rs .uniform (size = (M , 1 , self ._G .N ))
283+ x = self ._rs .uniform (size = (M , self ._G .N ))
284284 f = filters .Chebyshev (self ._G , c )
285285 y = f .filter (x , method = 'recursive' )
286286 np .testing .assert_equal (y , c * x )
@@ -331,3 +331,56 @@ def test_approximations(self, N=100, K=20):
331331 y1 = f1 .filter (x .T ).T
332332 y2 = f2 .filter (x )
333333 np .testing .assert_allclose (y2 .squeeze (), y1 )
334+
335+ def test_shape_normalization (self ):
336+ """Test that signal's shapes are properly normalized."""
337+ # TODO: should also test filters which are not approximations.
338+
339+ def test_normalization (M , Fin , Fout , K = 7 ):
340+
341+ def test_shape (y , M , Fout , N = self ._G .N ):
342+ """Test that filtered signals are squeezed."""
343+ if Fout == 1 and M == 1 :
344+ self .assertEqual (y .shape , (N ,))
345+ elif Fout == 1 :
346+ self .assertEqual (y .shape , (M , N ))
347+ elif M == 1 :
348+ self .assertEqual (y .shape , (Fout , N ))
349+ else :
350+ self .assertEqual (y .shape , (M , Fout , N ))
351+
352+ coefficients = self ._rs .uniform (size = (K , Fout , Fin ))
353+ f = filters .Chebyshev (self ._G , coefficients )
354+ assert f .shape == (Fin , Fout )
355+ assert (f .n_features_in , f .n_features_out ) == (Fin , Fout )
356+
357+ x = self ._rs .uniform (size = (M , Fin , self ._G .N ))
358+ y = f .filter (x )
359+ test_shape (y , M , Fout )
360+
361+ if Fin == 1 or M == 1 :
362+ # It only makes sense to squeeze if one dimension is unitary.
363+ x = x .squeeze ()
364+ y = f .filter (x )
365+ test_shape (y , M , Fout )
366+
367+ # Test all possible correct combinations of input and output signals.
368+ for M in [1 , 9 ]:
369+ for Fin in [1 , 3 ]:
370+ for Fout in [1 , 5 ]:
371+ test_normalization (M , Fin , Fout )
372+
373+ # Test failure cases.
374+ M , Fin , Fout , K = 9 , 3 , 5 , 7
375+ coefficients = self ._rs .uniform (size = (K , Fout , Fin ))
376+ f = filters .Chebyshev (self ._G , coefficients )
377+ x = self ._rs .uniform (size = (M , Fin , 2 ))
378+ self .assertRaises (ValueError , f .filter , x )
379+ x = self ._rs .uniform (size = (M , 2 , self ._G .N ))
380+ self .assertRaises (ValueError , f .filter , x )
381+ x = self ._rs .uniform (size = (2 , self ._G .N ))
382+ self .assertRaises (ValueError , f .filter , x )
383+ x = self ._rs .uniform (size = (self ._G .N ))
384+ self .assertRaises (ValueError , f .filter , x )
385+ x = self ._rs .uniform (size = (2 , M , Fin , self ._G .N ))
386+ self .assertRaises (ValueError , f .filter , x )
0 commit comments