Skip to content

Commit c1a58b7

Browse files
Added tests for int4 packing logic
1 parent 71c116a commit c1a58b7

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

keras/src/quantizers/quantizers_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,46 @@ def test_quantize_and_dequantize(self):
107107
# A loose assertion due to an expected quantization error
108108
self.assertAllClose(qdq_values, values, atol=5e-1)
109109

110+
@parameterized.named_parameters(
111+
("even_rows", (4, 5), 0),
112+
("odd_rows", (5, 5), 0),
113+
("even_rows_axis_1", (4, 6), 1),
114+
("odd_rows_axis_1", (4, 7), 1),
115+
("3d_even_rows_axis_0", (4, 5, 3), 0),
116+
("3d_odd_rows_axis_0", (5, 5, 3), 0),
117+
("3d_even_rows_axis_1", (4, 6, 3), 1),
118+
("3d_odd_rows_axis_1", (4, 7, 3), 1),
119+
("3d_even_rows_axis_2", (4, 5, 6), 2),
120+
("3d_odd_rows_axis_2", (4, 5, 7), 2),
121+
("4d_odd_rows_axis_0", (2, 3, 5, 4), 0),
122+
("4d_odd_rows_axis_1", (2, 3, 5, 4), 1),
123+
("4d_odd_rows_axis_2", (2, 3, 5, 4), 2),
124+
("4d_odd_rows_axis_3", (2, 3, 5, 4), 3),
125+
("4d_even_rows_axis_0", (2, 4, 5, 4), 0),
126+
("4d_even_rows_axis_1", (2, 4, 5, 4), 1),
127+
("4d_even_rows_axis_2", (2, 4, 5, 4), 2),
128+
("4d_even_rows_axis_3", (2, 4, 5, 4), 3),
129+
)
130+
def test_pack_unpack_int4(self, shape, axis):
131+
# Create a random tensor with int4 values [-8, 7]
132+
arr = ops.cast(
133+
ops.floor(random.uniform(shape, minval=-8, maxval=8)), "int8"
134+
)
135+
136+
# Pack the tensor
137+
packed, packed_shape, orig_len = quantizers.pack_int4(arr, axis=axis)
138+
139+
# Unpack the tensor
140+
unpacked = quantizers.unpack_int4(packed, orig_len, axis=axis)
141+
142+
# The unpacked tensor should be the same as the original tensor
143+
self.assertAllClose(unpacked, arr)
144+
145+
# Test the packed shape
146+
expected_packed_shape = list(shape)
147+
expected_packed_shape[axis] = (expected_packed_shape[axis] + 1) // 2
148+
self.assertEqual(list(ops.convert_to_numpy(packed_shape)), expected_packed_shape)
149+
110150
@parameterized.named_parameters(
111151
("per_tensor", None),
112152
("per_channel", -1),

0 commit comments

Comments
 (0)