Skip to content

Commit 777b5e6

Browse files
generalize int4 packing
1 parent dd11851 commit 777b5e6

File tree

2 files changed

+106
-35
lines changed

2 files changed

+106
-35
lines changed

keras/src/quantizers/quantizers.py

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
377377

378378

379379
@keras_export("keras.quantizers.pack_int4")
380-
def pack_int4(arr):
380+
def pack_int4(arr, axis=0):
381381
"""Pack an int4 tensor into an int8 tensor with packed nibbles.
382382
383383
Accepts a Keras-compatible tensor. The input values must already be int8
@@ -396,43 +396,72 @@ def pack_int4(arr):
396396
if arr.dtype != "int8":
397397
raise TypeError("Expected int8 tensor for packing")
398398

399-
shape = ops.shape(arr)
400-
rows, cols = shape[0], shape[1]
399+
rank = arr.shape.rank
400+
# 1. Bring `axis` to the front.
401+
perm = [axis] + [i for i in range(rank) if i != axis]
402+
inv_perm = [perm.index(i) for i in range(rank)]
403+
transposed = ops.transpose(arr, perm)
401404

402-
orig_rows = rows
403-
if rows % 2 == 1:
404-
padding_row = ops.zeros((1, cols), dtype="int8")
405-
arr = ops.concatenate([arr, padding_row], axis=0)
406-
rows += 1
405+
# 2. Pad to even length.
406+
rows = ops.shape(transposed)[0]
407+
needs_pad = ops.equal(ops.mod(rows, 2), 1)
407408

408-
# Map signed [-8,7] to unsigned 4-bit two's complement (0..15)
409-
arr_u = ops.where(arr < 0, arr + 16, arr)
410-
arr_u = ops.cast(arr_u, "uint8")
411-
arr_u = ops.reshape(arr_u, (rows // 2, 2, cols))
412-
low = arr_u[:, 0, :]
413-
high = arr_u[:, 1, :]
414-
packed = ops.bitwise_or(ops.left_shift(high, 4), low)
415-
packed = ops.cast(packed, "int8")
416-
return packed, ops.shape(packed), orig_rows
409+
def _pad(x):
410+
pad_shape = ops.concatenate(
411+
[ops.array([1]), ops.array(ops.shape(x)[1:])], axis=0
412+
)
413+
pad_row = ops.zeros(pad_shape, dtype="int8")
414+
return ops.concatenate([x, pad_row], axis=0)
415+
416+
transposed = ops.cond(
417+
needs_pad, lambda: _pad(transposed), lambda: transposed
418+
)
419+
rows_padded = ops.shape(transposed)[0]
420+
421+
# 3-4. Group in pairs and pack.
422+
flat_tail = ops.reshape(transposed, (rows_padded // 2, 2, -1))
423+
low = flat_tail[:, 0, :]
424+
high = flat_tail[:, 1, :]
425+
low_u = ops.where(low < 0, low + 16, low)
426+
high_u = ops.where(high < 0, high + 16, high)
427+
packed_flat = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
428+
packed_flat = ops.cast(packed_flat, "int8")
429+
430+
# 5-6. Restore shape.
431+
packed = ops.reshape(
432+
packed_flat,
433+
ops.concatenate(
434+
[
435+
ops.expand_dims(rows_padded // 2, 0),
436+
ops.array(ops.shape(transposed)[1:]),
437+
],
438+
axis=0,
439+
),
440+
)
441+
packed = ops.transpose(packed, inv_perm) # back to original order
442+
orig_len = rows # number of slices before padding
443+
return packed, ops.shape(packed), orig_len
417444

418445

419446
@keras_export("keras.quantizers.unpack_int4")
420-
def unpack_int4(packed, orig_rows):
447+
def unpack_int4(packed, orig_len, axis=0):
421448
"""Unpack packed int4 tensor (ops) to int8 [-8,7]."""
422-
# Bitwise operations work element-wise.
423-
low = ops.bitwise_and(packed, 0x0F)
424-
high = ops.right_shift(packed, 4)
425-
high = ops.bitwise_and(high, 0x0F)
426-
427-
def _to_signed(x):
428-
return ops.where(x < 8, x, ops.subtract(x, 16))
429-
430-
low = _to_signed(low)
431-
high = _to_signed(high)
432-
433-
# Interleave rows back: stacked shape (2, packed_rows, cols)
434-
stacked = ops.stack([low, high], axis=1) # (pairs, 2, cols)
435-
unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1]))
436-
# Remove potential padded row.
437-
unpacked = unpacked_full[:orig_rows, :]
438-
return unpacked
449+
rank = packed.shape.rank
450+
perm = [axis] + [i for i in range(rank) if i != axis]
451+
inv_perm = [perm.index(i) for i in range(rank)]
452+
transposed = ops.transpose(packed, perm)
453+
454+
# 1. Split nibbles.
455+
low = ops.bitwise_and(transposed, 0x0F)
456+
high = ops.bitwise_and(ops.right_shift(transposed, 4), 0x0F)
457+
to_signed = lambda x: ops.where(x < 8, x, x - 16)
458+
low = to_signed(low)
459+
high = to_signed(high)
460+
461+
# 2. Interleave.
462+
stacked = ops.stack([low, high], axis=1) # (pairs, 2, …)
463+
unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))
464+
465+
# 3. Remove possible padding and restore layout.
466+
unpacked = unpacked[:orig_len, ...]
467+
return ops.transpose(unpacked, inv_perm)

keras/src/quantizers/quantizers_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,48 @@ def test_quantize_and_dequantize(self):
107107
# A loose assertion due to an expected quantization error
108108
self.assertAllClose(qdq_values, values, atol=5e-1)
109109

110+
@parameterized.named_parameters(
    ("even_rows", (4, 5), 0),
    ("odd_rows", (5, 5), 0),
    ("even_rows_axis_1", (4, 6), 1),
    ("odd_rows_axis_1", (4, 7), 1),
    ("3d_even_rows_axis_0", (4, 5, 3), 0),
    ("3d_odd_rows_axis_0", (5, 5, 3), 0),
    ("3d_even_rows_axis_1", (4, 6, 3), 1),
    ("3d_odd_rows_axis_1", (4, 7, 3), 1),
    ("3d_even_rows_axis_2", (4, 5, 6), 2),
    ("3d_odd_rows_axis_2", (4, 5, 7), 2),
    ("4d_odd_rows_axis_0", (2, 3, 5, 4), 0),
    ("4d_odd_rows_axis_1", (2, 3, 5, 4), 1),
    ("4d_odd_rows_axis_2", (2, 3, 5, 4), 2),
    ("4d_odd_rows_axis_3", (2, 3, 5, 4), 3),
    ("4d_even_rows_axis_0", (2, 4, 5, 4), 0),
    ("4d_even_rows_axis_1", (2, 4, 5, 4), 1),
    ("4d_even_rows_axis_2", (2, 4, 5, 4), 2),
    ("4d_even_rows_axis_3", (2, 4, 5, 4), 3),
)
def test_pack_unpack_int4(self, shape, axis):
    """`pack_int4` followed by `unpack_int4` round-trips exactly."""
    # Random int8 values restricted to the signed int4 range [-8, 7].
    values = ops.cast(
        ops.floor(random.uniform(shape, minval=-8, maxval=8)), "int8"
    )

    # Pack along `axis`, then unpack along the same axis.
    packed, packed_shape, orig_len = quantizers.pack_int4(values, axis=axis)
    restored = quantizers.unpack_int4(packed, orig_len, axis=axis)

    # The round trip must be lossless.
    self.assertAllClose(restored, values)

    # The packed axis shrinks to ceil(len / 2); other axes are unchanged.
    expected_shape = list(shape)
    expected_shape[axis] = (expected_shape[axis] + 1) // 2
    self.assertEqual(
        list(ops.convert_to_numpy(packed_shape)), expected_shape
    )
151+
110152
@parameterized.named_parameters(
111153
("per_tensor", None),
112154
("per_channel", -1),

0 commit comments

Comments
 (0)