Skip to content

Commit dd11851

Browse files
refactor packing utils into quantizers
1 parent c02e302 commit dd11851

File tree

5 files changed

+72
-63
lines changed

5 files changed

+72
-63
lines changed

keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from keras.src.quantizers.quantizers import (
2020
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
2121
)
22+
from keras.src.quantizers.quantizers import pack_int4 as pack_int4
2223
from keras.src.quantizers.quantizers import (
2324
quantize_and_dequantize as quantize_and_dequantize,
2425
)
26+
from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4

keras/api/quantizers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from keras.src.quantizers.quantizers import (
2020
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
2121
)
22+
from keras.src.quantizers.quantizers import pack_int4 as pack_int4
2223
from keras.src.quantizers.quantizers import (
2324
quantize_and_dequantize as quantize_and_dequantize,
2425
)
26+
from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4

keras/src/layers/core/dense.py

Lines changed: 4 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def _int4_call(self, inputs, training=None):
467467

468468
@ops.custom_gradient
469469
def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
470-
unpacked_kernel = self._unpack_int4_ops(
470+
unpacked_kernel = quantizers.unpack_int4(
471471
kernel, self._orig_input_dim
472472
)
473473

@@ -623,7 +623,7 @@ def quantize(self, mode, type_check=True):
623623
)
624624
kernel_scale = ops.squeeze(kernel_scale, axis=0)
625625
# 2. Pack two int4 values into a single int8 byte.
626-
packed_kernel_value, packed_shape, orig_rows = self._pack_int4_ops(
626+
packed_kernel_value, _, orig_rows = quantizers.pack_int4(
627627
kernel_value_int4
628628
)
629629
del self._kernel
@@ -658,7 +658,7 @@ def _get_kernel_with_merged_lora(self):
658658
# update, and then pack it again after requantization.
659659
if self.quantization_mode == "int4":
660660
# 1) Unpack packed int4 tensor to int8 range [-8, 7].
661-
unpacked_kernel = self._unpack_int4_ops(
661+
unpacked_kernel = quantizers.unpack_int4(
662662
kernel_value, self._orig_input_dim
663663
)
664664
# 2) De-scale to recover float32 kernel.
@@ -679,7 +679,7 @@ def _get_kernel_with_merged_lora(self):
679679
)
680680
kernel_scale = ops.squeeze(kernel_scale, axis=0)
681681
# 5) Pack the int4 values back into the compact int8 layout.
682-
kernel_value, _, _ = self._pack_int4_ops(kernel_int4)
682+
kernel_value, _, _ = quantizers.pack_int4(kernel_int4)
683683
else:
684684
# int8 path (regular): unpacking not required.
685685
kernel_value = ops.divide(kernel_value, kernel_scale)
@@ -694,62 +694,3 @@ def _get_kernel_with_merged_lora(self):
694694
kernel_scale = ops.squeeze(kernel_scale, axis=0)
695695
return kernel_value, kernel_scale
696696
return self.kernel, None
697-
698-
def _pack_int4_ops(self, arr):
699-
"""Pack an int4 tensor into an int8 tensor with packed nibbles.
700-
701-
Accepts a Keras-compatible tensor. The input values must already be int8
702-
in the signed range ``[-8, 7]`` and represent the desired int4 values.
703-
Packing is performed along axis 0:
704-
705-
* For every two consecutive rows, the **low nibble** of the output byte
706-
stores the value from the first row, and the **high nibble** stores
707-
the value from the second row.
708-
709-
Returns a tuple ``(packed, packed_shape, orig_rows)`` where ``packed``
710-
is the packed ``int8`` tensor, ``packed_shape`` is its shape, and
711-
``orig_rows`` is the original (unpacked) row count prior to any padding
712-
that may have been inserted when an odd number of rows is supplied.
713-
"""
714-
if arr.dtype != "int8":
715-
raise TypeError("Expected int8 tensor for packing")
716-
717-
shape = ops.shape(arr)
718-
rows, cols = shape[0], shape[1]
719-
720-
orig_rows = rows
721-
if rows % 2 == 1:
722-
padding_row = ops.zeros((1, cols), dtype="int8")
723-
arr = ops.concatenate([arr, padding_row], axis=0)
724-
rows += 1
725-
726-
# Map signed [-8,7] to unsigned 4-bit two's complement (0..15)
727-
arr_u = ops.where(arr < 0, arr + 16, arr)
728-
arr_u = ops.cast(arr_u, "uint8")
729-
arr_u = ops.reshape(arr_u, (rows // 2, 2, cols))
730-
low = arr_u[:, 0, :]
731-
high = arr_u[:, 1, :]
732-
packed = ops.bitwise_or(ops.left_shift(high, 4), low)
733-
packed = ops.cast(packed, "int8")
734-
return packed, ops.shape(packed), orig_rows
735-
736-
@staticmethod
737-
def _unpack_int4_ops(packed, orig_rows):
738-
"""Unpack packed int4 tensor (ops) to int8 [-8,7]."""
739-
# Bitwise operations work element-wise.
740-
low = ops.bitwise_and(packed, 0x0F)
741-
high = ops.right_shift(packed, 4)
742-
high = ops.bitwise_and(high, 0x0F)
743-
744-
def _to_signed(x):
745-
return ops.where(x < 8, x, ops.subtract(x, 16))
746-
747-
low = _to_signed(low)
748-
high = _to_signed(high)
749-
750-
# Interleave rows back: stacked shape (2, packed_rows, cols)
751-
stacked = ops.stack([low, high], axis=1) # (pairs, 2, cols)
752-
unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1]))
753-
# Remove potential padded row.
754-
unpacked = unpacked_full[:orig_rows, :]
755-
return unpacked

keras/src/quantizers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from keras.src.quantizers.quantizers import compute_float8_amax_history
88
from keras.src.quantizers.quantizers import compute_float8_scale
99
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
10+
from keras.src.quantizers.quantizers import pack_int4
1011
from keras.src.quantizers.quantizers import quantize_and_dequantize
12+
from keras.src.quantizers.quantizers import unpack_int4
1113
from keras.src.saving import serialization_lib
1214
from keras.src.utils.naming import to_snake_case
1315

keras/src/quantizers/quantizers.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,65 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
374374
# Dequantize
375375
x = ops.multiply(ops.cast(x, compute_dtype), ops.cast(scale, compute_dtype))
376376
return x
377+
378+
379+
@keras_export("keras.quantizers.pack_int4")
def pack_int4(arr):
    """Pack a tensor of int4 values (stored as int8) into nibble pairs.

    The input must be an ``int8`` tensor whose values already lie in the
    signed 4-bit range ``[-8, 7]``. Packing runs along axis 0: each pair
    of consecutive rows is fused into one byte, with the first row of the
    pair stored in the low nibble and the second row in the high nibble.
    When the row count is odd, a zero row is appended before packing.

    Returns:
        A tuple ``(packed, packed_shape, orig_rows)`` where ``packed`` is
        the packed ``int8`` tensor, ``packed_shape`` is its shape, and
        ``orig_rows`` is the row count of the input prior to any padding.
    """
    if arr.dtype != "int8":
        raise TypeError("Expected int8 tensor for packing")

    num_rows = ops.shape(arr)[0]
    num_cols = ops.shape(arr)[1]
    orig_rows = num_rows

    # Guarantee an even number of rows so they can be grouped in pairs.
    if num_rows % 2 == 1:
        zero_row = ops.zeros((1, num_cols), dtype="int8")
        arr = ops.concatenate([arr, zero_row], axis=0)
        num_rows += 1

    # Convert signed values to their unsigned two's-complement nibble
    # representation (0..15) before doing bit manipulation.
    unsigned = ops.where(arr < 0, arr + 16, arr)
    unsigned = ops.cast(unsigned, "uint8")
    pairs = ops.reshape(unsigned, (num_rows // 2, 2, num_cols))
    low_nibble = pairs[:, 0, :]
    high_nibble = pairs[:, 1, :]
    packed = ops.bitwise_or(ops.left_shift(high_nibble, 4), low_nibble)
    packed = ops.cast(packed, "int8")
    return packed, ops.shape(packed), orig_rows
417+
418+
419+
@keras_export("keras.quantizers.unpack_int4")
def unpack_int4(packed, orig_rows):
    """Recover int8 values in ``[-8, 7]`` from a nibble-packed tensor.

    Args:
        packed: ``int8`` tensor produced by ``pack_int4``; each byte holds
            two 4-bit values, low nibble first.
        orig_rows: Row count of the tensor before packing, used to strip
            the padding row inserted when the original count was odd.

    Returns:
        The unpacked ``int8`` tensor with values in ``[-8, 7]``.
    """
    # Split every byte into its two unsigned nibbles (element-wise ops).
    low_nibble = ops.bitwise_and(packed, 0x0F)
    high_nibble = ops.bitwise_and(ops.right_shift(packed, 4), 0x0F)

    def _as_signed(nibble):
        # Map the unsigned nibble range 0..15 back to two's-complement
        # signed values -8..7.
        return ops.where(nibble < 8, nibble, ops.subtract(nibble, 16))

    low_nibble = _as_signed(low_nibble)
    high_nibble = _as_signed(high_nibble)

    # Re-interleave rows: stacking to (pairs, 2, cols) and flattening
    # restores each pair's low row immediately before its high row.
    interleaved = ops.stack([low_nibble, high_nibble], axis=1)
    flat = ops.reshape(interleaved, (-1, interleaved.shape[-1]))
    # Drop the zero row that pack_int4 may have appended for odd inputs.
    return flat[:orig_rows, :]

0 commit comments

Comments
 (0)