@@ -26,7 +26,7 @@ def test_constant_pad(
2626 x = np .random .normal (size = size ).astype (floatX )
2727 expected = np .pad (x , pad_width , mode = "constant" , constant_values = constant )
2828 z = pad (x , pad_width , mode = "constant" , constant_values = constant )
29- assert z .pad_mode == "constant"
29+ assert z .owner . op . pad_mode == "constant"
3030
3131 f = pytensor .function ([], z , mode = "FAST_COMPILE" )
3232
@@ -43,7 +43,7 @@ def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]):
4343 x = np .random .normal (size = size ).astype (floatX )
4444 expected = np .pad (x , pad_width , mode = "edge" )
4545 z = pad (x , pad_width , mode = "edge" )
46- assert z .pad_mode == "edge"
46+ assert z .owner . op . pad_mode == "edge"
4747
4848 f = pytensor .function ([], z , mode = "FAST_COMPILE" )
4949
@@ -65,7 +65,7 @@ def test_linear_ramp_pad(
6565 x = np .random .normal (size = size ).astype (floatX )
6666 expected = np .pad (x , pad_width , mode = "linear_ramp" , end_values = end_values )
6767 z = pad (x , pad_width , mode = "linear_ramp" , end_values = end_values )
68- assert z .pad_mode == "linear_ramp"
68+ assert z .owner . op . pad_mode == "linear_ramp"
6969
7070 f = pytensor .function ([], z , mode = "FAST_COMPILE" )
7171
@@ -89,8 +89,7 @@ def test_stat_pad(
8989 x = np .random .normal (size = size ).astype (floatX )
9090 expected = np .pad (x , pad_width , mode = stat , stat_length = stat_length )
9191 z = pad (x , pad_width , mode = stat , stat_length = stat_length )
92- assert z .pad_mode == stat
93- assert z .stat_length_input == (stat_length is not None )
92+ assert z .owner .op .pad_mode == stat
9493
9594 f = pytensor .function ([], z , mode = "FAST_COMPILE" )
9695
@@ -107,7 +106,7 @@ def test_wrap_pad(size: tuple, pad_width: int | tuple[int, ...]):
107106 x = np .random .normal (size = size ).astype (floatX )
108107 expected = np .pad (x , pad_width , mode = "wrap" )
109108 z = pad (x , pad_width , mode = "wrap" )
110- assert z .pad_mode == "wrap"
109+ assert z .owner . op . pad_mode == "wrap"
111110 f = pytensor .function ([], z , mode = "FAST_COMPILE" )
112111
113112 np .testing .assert_allclose (expected , f (), atol = ATOL , rtol = RTOL )
@@ -128,7 +127,50 @@ def test_symmetric_pad(size, pad_width, reflect_type):
128127 x = np .random .normal (size = size ).astype (floatX )
129128 expected = np .pad (x , pad_width , mode = "symmetric" , reflect_type = reflect_type )
130129 z = pad (x , pad_width , mode = "symmetric" , reflect_type = reflect_type )
131- assert z .pad_mode == "symmetric"
130+ assert z .owner .op .pad_mode == "symmetric"
131+ f = pytensor .function ([], z , mode = "FAST_COMPILE" )
132+
133+ np .testing .assert_allclose (expected , f (), atol = ATOL , rtol = RTOL )
134+
135+
136+ @pytest .mark .parametrize (
137+ "mode" ,
138+ [
139+ "constant" ,
140+ "edge" ,
141+ "linear_ramp" ,
142+ "wrap" ,
143+ "symmetric" ,
144+ "mean" ,
145+ "maximum" ,
146+ "minimum" ,
147+ ],
148+ )
149+ @pytest .mark .parametrize ("padding" , ["symmetric" , "asymmetric" ])
150+ def test_nd_padding (mode , padding ):
151+ rng = np .random .default_rng ()
152+ n = rng .integers (3 , 10 )
153+ if padding == "symmetric" :
154+ pad_width = [(i , i ) for i in rng .integers (1 , 5 , size = n )]
155+ stat_length = [(i , i ) for i in rng .integers (1 , 5 , size = n )]
156+ else :
157+ pad_width = rng .integers (1 , 5 , size = (n , 2 )).tolist ()
158+ stat_length = rng .integers (1 , 5 , size = (n , 2 )).tolist ()
159+
160+ test_kwargs = {
161+ "constant" : {"constant_values" : 0 },
162+ "linear_ramp" : {"end_values" : 0 },
163+ "maximum" : {"stat_length" : stat_length },
164+ "mean" : {"stat_length" : stat_length },
165+ "minimum" : {"stat_length" : stat_length },
166+ "reflect" : {"reflect_type" : "even" },
167+ "symmetric" : {"reflect_type" : "even" },
168+ }
169+
170+ x = np .random .normal (size = (2 ,) * n ).astype (floatX )
171+ kwargs = test_kwargs .get (mode , {})
172+ expected = np .pad (x , pad_width , mode = mode , ** kwargs )
173+ z = pad (x , pad_width , mode = mode , ** kwargs )
132174 f = pytensor .function ([], z , mode = "FAST_COMPILE" )
133175
134176 np .testing .assert_allclose (expected , f (), atol = ATOL , rtol = RTOL )
0 commit comments