@@ -107,6 +107,46 @@ def test_quantize_and_dequantize(self):
107
107
# A loose assertion due to an expected quantization error
108
108
self .assertAllClose (qdq_values , values , atol = 5e-1 )
109
109
110
+ @parameterized .named_parameters (
111
+ ("even_rows" , (4 , 5 ), 0 ),
112
+ ("odd_rows" , (5 , 5 ), 0 ),
113
+ ("even_rows_axis_1" , (4 , 6 ), 1 ),
114
+ ("odd_rows_axis_1" , (4 , 7 ), 1 ),
115
+ ("3d_even_rows_axis_0" , (4 , 5 , 3 ), 0 ),
116
+ ("3d_odd_rows_axis_0" , (5 , 5 , 3 ), 0 ),
117
+ ("3d_even_rows_axis_1" , (4 , 6 , 3 ), 1 ),
118
+ ("3d_odd_rows_axis_1" , (4 , 7 , 3 ), 1 ),
119
+ ("3d_even_rows_axis_2" , (4 , 5 , 6 ), 2 ),
120
+ ("3d_odd_rows_axis_2" , (4 , 5 , 7 ), 2 ),
121
+ ("4d_odd_rows_axis_0" , (2 , 3 , 5 , 4 ), 0 ),
122
+ ("4d_odd_rows_axis_1" , (2 , 3 , 5 , 4 ), 1 ),
123
+ ("4d_odd_rows_axis_2" , (2 , 3 , 5 , 4 ), 2 ),
124
+ ("4d_odd_rows_axis_3" , (2 , 3 , 5 , 4 ), 3 ),
125
+ ("4d_even_rows_axis_0" , (2 , 4 , 5 , 4 ), 0 ),
126
+ ("4d_even_rows_axis_1" , (2 , 4 , 5 , 4 ), 1 ),
127
+ ("4d_even_rows_axis_2" , (2 , 4 , 5 , 4 ), 2 ),
128
+ ("4d_even_rows_axis_3" , (2 , 4 , 5 , 4 ), 3 ),
129
+ )
130
+ def test_pack_unpack_int4 (self , shape , axis ):
131
+ # Create a random tensor with int4 values [-8, 7]
132
+ arr = ops .cast (
133
+ ops .floor (random .uniform (shape , minval = - 8 , maxval = 8 )), "int8"
134
+ )
135
+
136
+ # Pack the tensor
137
+ packed , packed_shape , orig_len = quantizers .pack_int4 (arr , axis = axis )
138
+
139
+ # Unpack the tensor
140
+ unpacked = quantizers .unpack_int4 (packed , orig_len , axis = axis )
141
+
142
+ # The unpacked tensor should be the same as the original tensor
143
+ self .assertAllClose (unpacked , arr )
144
+
145
+ # Test the packed shape
146
+ expected_packed_shape = list (shape )
147
+ expected_packed_shape [axis ] = (expected_packed_shape [axis ] + 1 ) // 2
148
+ self .assertEqual (list (ops .convert_to_numpy (packed_shape )), expected_packed_shape )
149
+
110
150
@parameterized .named_parameters (
111
151
("per_tensor" , None ),
112
152
("per_channel" , - 1 ),
0 commit comments