@@ -467,7 +467,7 @@ def _int4_call(self, inputs, training=None):
 
         @ops.custom_gradient
         def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
-            unpacked_kernel = self._unpack_int4_ops(
+            unpacked_kernel = quantizers.unpack_int4(
                 kernel, self._orig_input_dim
             )
 
@@ -623,7 +623,7 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
-            packed_kernel_value, packed_shape, orig_rows = self._pack_int4_ops(
+            packed_kernel_value, _, orig_rows = quantizers.pack_int4(
                 kernel_value_int4
             )
             del self._kernel
@@ -658,7 +658,7 @@ def _get_kernel_with_merged_lora(self):
                 # update, and then pack it again after requantization.
                 if self.quantization_mode == "int4":
                     # 1) Unpack packed int4 tensor to int8 range [-8, 7].
-                    unpacked_kernel = self._unpack_int4_ops(
+                    unpacked_kernel = quantizers.unpack_int4(
                         kernel_value, self._orig_input_dim
                     )
                     # 2) De-scale to recover float32 kernel.
@@ -679,7 +679,7 @@ def _get_kernel_with_merged_lora(self):
                     )
                     kernel_scale = ops.squeeze(kernel_scale, axis=0)
                     # 5) Pack the int4 values back into the compact int8 layout.
-                    kernel_value, _, _ = self._pack_int4_ops(kernel_int4)
+                    kernel_value, _, _ = quantizers.pack_int4(kernel_int4)
                 else:
                     # int8 path (regular): unpacking not required.
                     kernel_value = ops.divide(kernel_value, kernel_scale)
@@ -694,62 +694,3 @@ def _get_kernel_with_merged_lora(self):
                 kernel_scale = ops.squeeze(kernel_scale, axis=0)
             return kernel_value, kernel_scale
         return self.kernel, None
-
-    def _pack_int4_ops(self, arr):
-        """Pack an int4 tensor into an int8 tensor with packed nibbles.
-
-        Accepts a Keras-compatible tensor. The input values must already be int8
-        in the signed range ``[-8, 7]`` and represent the desired int4 values.
-        Packing is performed along axis 0:
-
-        * For every two consecutive rows, the **low nibble** of the output byte
-          stores the value from the first row, and the **high nibble** stores
-          the value from the second row.
-
-        Returns a tuple ``(packed, packed_shape, orig_rows)`` where ``packed``
-        is the packed ``int8`` tensor, ``packed_shape`` is its shape, and
-        ``orig_rows`` is the original (unpacked) row count prior to any padding
-        that may have been inserted when an odd number of rows is supplied.
-        """
-        if arr.dtype != "int8":
-            raise TypeError("Expected int8 tensor for packing")
-
-        shape = ops.shape(arr)
-        rows, cols = shape[0], shape[1]
-
-        orig_rows = rows
-        if rows % 2 == 1:
-            padding_row = ops.zeros((1, cols), dtype="int8")
-            arr = ops.concatenate([arr, padding_row], axis=0)
-            rows += 1
-
-        # Map signed [-8,7] to unsigned 4-bit two's complement (0..15)
-        arr_u = ops.where(arr < 0, arr + 16, arr)
-        arr_u = ops.cast(arr_u, "uint8")
-        arr_u = ops.reshape(arr_u, (rows // 2, 2, cols))
-        low = arr_u[:, 0, :]
-        high = arr_u[:, 1, :]
-        packed = ops.bitwise_or(ops.left_shift(high, 4), low)
-        packed = ops.cast(packed, "int8")
-        return packed, ops.shape(packed), orig_rows
-
-    @staticmethod
-    def _unpack_int4_ops(packed, orig_rows):
-        """Unpack packed int4 tensor (ops) to int8 [-8,7]."""
-        # Bitwise operations work element-wise.
-        low = ops.bitwise_and(packed, 0x0F)
-        high = ops.right_shift(packed, 4)
-        high = ops.bitwise_and(high, 0x0F)
-
-        def _to_signed(x):
-            return ops.where(x < 8, x, ops.subtract(x, 16))
-
-        low = _to_signed(low)
-        high = _to_signed(high)
-
-        # Interleave rows back: stacked shape (pairs, 2, cols)
-        stacked = ops.stack([low, high], axis=1)  # (pairs, 2, cols)
-        unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1]))
-        # Remove potential padded row.
-        unpacked = unpacked_full[:orig_rows, :]
-        return unpacked
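
For readers unfamiliar with the nibble layout the deleted helpers describe, here is a minimal standalone NumPy sketch of the same packing scheme. It is an illustration of the technique only, not the keras.ops-based code this commit delegates to; the pack_int4/unpack_int4 names below are local to the sketch.

import numpy as np

def pack_int4(arr):
    """Pack int8 values in [-8, 7] two-per-byte along axis 0."""
    rows, cols = arr.shape
    orig_rows = rows
    if rows % 2 == 1:
        # Pad with a zero row so rows can be grouped into pairs.
        arr = np.concatenate([arr, np.zeros((1, cols), dtype=np.int8)], axis=0)
        rows += 1
    # Map signed [-8, 7] to unsigned two's-complement nibbles 0..15.
    u = np.where(arr < 0, arr + 16, arr).astype(np.uint8)
    u = u.reshape(rows // 2, 2, cols)
    low, high = u[:, 0, :], u[:, 1, :]  # first row of each pair -> low nibble
    packed = ((high << 4) | low).astype(np.int8)
    return packed, packed.shape, orig_rows

def unpack_int4(packed, orig_rows):
    """Invert pack_int4, trimming any padded row."""
    u = packed.astype(np.uint8)
    low, high = u & 0x0F, (u >> 4) & 0x0F

    def to_signed(x):
        # Interpret 0..15 as two's-complement int4 values -8..7.
        x = x.astype(np.int16)
        return np.where(x < 8, x, x - 16).astype(np.int8)

    stacked = np.stack([to_signed(low), to_signed(high)], axis=1)  # (pairs, 2, cols)
    return stacked.reshape(-1, packed.shape[-1])[:orig_rows, :]

kernel_int4 = np.array([[-8, 7, 0], [3, -1, 5], [2, 2, -4]], dtype=np.int8)
packed, _, orig_rows = pack_int4(kernel_int4)
assert np.array_equal(unpack_int4(packed, orig_rows), kernel_int4)

Because two signed int4 values share one int8 byte, the packed kernel has roughly half the rows of the original, which is why the pre-padding row count (orig_rows) must travel alongside the packed tensor so the padded row can be trimmed on unpack.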
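Reading the new call sites together, the intended round-trip through the public helpers looks roughly like the sketch below. This is a hedged usage example: the `quantizers` module path and the acceptance of a plain NumPy array are assumptions inferred only from how this layer file already imports and calls the module, not verified against a released Keras API.

# Assumption: a Keras source tree containing this commit, where the layer file
# imports `quantizers` (keras.src.quantizers) at module level.
import numpy as np
from keras.src import quantizers

kernel_value_int4 = np.array([[-8, 7], [3, -1], [0, 5]], dtype="int8")
# pack_int4 returns (packed, packed_shape, orig_rows), mirroring the tuple
# unpacked in quantize() above.
packed, packed_shape, orig_rows = quantizers.pack_int4(kernel_value_int4)
# unpack_int4 takes the original (pre-padding) row count, which is the role
# self._orig_input_dim plays in the layer code.
restored = quantizers.unpack_int4(packed, orig_rows)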