Skip to content

Commit 71c116a

Browse files
generalize int4 packing
1 parent dd11851 commit 71c116a

File tree

1 file changed

+64
-35
lines changed

1 file changed

+64
-35
lines changed

keras/src/quantizers/quantizers.py

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
377377

378378

379379
@keras_export("keras.quantizers.pack_int4")
380-
def pack_int4(arr):
380+
def pack_int4(arr, axis=0):
381381
"""Pack an int4 tensor into an int8 tensor with packed nibbles.
382382
383383
Accepts a Keras-compatible tensor. The input values must already be int8
@@ -396,43 +396,72 @@ def pack_int4(arr):
396396
if arr.dtype != "int8":
397397
raise TypeError("Expected int8 tensor for packing")
398398

399-
shape = ops.shape(arr)
400-
rows, cols = shape[0], shape[1]
399+
rank = arr.shape.rank
400+
# 1. Bring `axis` to the front.
401+
perm = [axis] + [i for i in range(rank) if i != axis]
402+
inv_perm = [perm.index(i) for i in range(rank)]
403+
transposed = ops.transpose(arr, perm)
401404

402-
orig_rows = rows
403-
if rows % 2 == 1:
404-
padding_row = ops.zeros((1, cols), dtype="int8")
405-
arr = ops.concatenate([arr, padding_row], axis=0)
406-
rows += 1
405+
# 2. Pad to even length.
406+
rows = ops.shape(transposed)[0]
407+
needs_pad = ops.equal(ops.mod(rows, 2), 1)
407408

408-
# Map signed [-8,7] to unsigned 4-bit two's complement (0..15)
409-
arr_u = ops.where(arr < 0, arr + 16, arr)
410-
arr_u = ops.cast(arr_u, "uint8")
411-
arr_u = ops.reshape(arr_u, (rows // 2, 2, cols))
412-
low = arr_u[:, 0, :]
413-
high = arr_u[:, 1, :]
414-
packed = ops.bitwise_or(ops.left_shift(high, 4), low)
415-
packed = ops.cast(packed, "int8")
416-
return packed, ops.shape(packed), orig_rows
409+
def _pad(x):
410+
pad_shape = ops.concatenate(
411+
[ops.array([1]), ops.array(ops.shape(x)[1:])], axis=0
412+
)
413+
pad_row = ops.zeros(pad_shape, dtype="int8")
414+
return ops.concatenate([x, pad_row], axis=0)
415+
416+
transposed = ops.cond(
417+
needs_pad, lambda: _pad(transposed), lambda: transposed
418+
)
419+
rows_padded = ops.shape(transposed)[0]
420+
421+
# 3-4. Group in pairs and pack.
422+
flat_tail = ops.reshape(transposed, (rows_padded // 2, 2, -1))
423+
low = flat_tail[:, 0, :]
424+
high = flat_tail[:, 1, :]
425+
low_u = ops.where(low < 0, low + 16, low)
426+
high_u = ops.where(high < 0, high + 16, high)
427+
packed_flat = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
428+
packed_flat = ops.cast(packed_flat, "int8")
429+
430+
# 5-6. Restore shape.
431+
packed = ops.reshape(
432+
packed_flat,
433+
ops.concatenate(
434+
[
435+
ops.expand_dims(rows_padded // 2, 0),
436+
ops.array(ops.shape(transposed)[1:]),
437+
],
438+
axis=0,
439+
),
440+
)
441+
packed = ops.transpose(packed, inv_perm) # back to original order
442+
orig_len = rows # number of slices before padding
443+
return packed, ops.shape(packed), orig_len
417444

418445

419446
@keras_export("keras.quantizers.unpack_int4")
def unpack_int4(packed, orig_len, axis=0):
    """Unpack a packed int4 tensor back to int8 values in `[-8, 7]`.

    Inverse of `pack_int4`: each int8 element of `packed` holds two 4-bit
    two's-complement nibbles packed along `axis` (low nibble = earlier
    slice, high nibble = later slice).

    Args:
        packed: int8 tensor produced by `pack_int4`.
        orig_len: Original (unpadded) length of the packed axis, as
            returned by `pack_int4`. Used to drop the padding slice that
            packing may have appended to make the length even.
        axis: Axis along which the values were packed. May be negative.

    Returns:
        int8 tensor with values in `[-8, 7]` and the original layout
        restored.
    """
    rank = packed.shape.rank
    # Normalize a negative axis. Without this, `i != axis` below can never
    # match a non-negative index, producing a malformed permutation of
    # length rank + 1.
    if axis < 0:
        axis += rank

    # Bring the packed axis to the front so nibble pairs are contiguous
    # along dimension 0; remember how to undo the permutation.
    perm = [axis] + [i for i in range(rank) if i != axis]
    inv_perm = [perm.index(i) for i in range(rank)]
    transposed = ops.transpose(packed, perm)

    # 1. Split each byte into its two 4-bit nibbles.
    low = ops.bitwise_and(transposed, 0x0F)
    high = ops.bitwise_and(ops.right_shift(transposed, 4), 0x0F)

    def _to_signed(x):
        # Map unsigned nibbles (0..15) back to signed two's complement.
        return ops.where(x < 8, x, x - 16)

    low = _to_signed(low)
    high = _to_signed(high)

    # 2. Interleave the low/high slices back into their original order
    # along dimension 0: (pairs, 2, ...) reshapes to (2 * pairs, ...).
    stacked = ops.stack([low, high], axis=1)
    unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))

    # 3. Drop the padding slice (if any) and restore the original layout.
    unpacked = unpacked[:orig_len, ...]
    return ops.transpose(unpacked, inv_perm)

0 commit comments

Comments
 (0)