@@ -2,7 +2,6 @@
 
 from keras.src import activations
 from keras.src import constraints
-from keras.src import dtype_policies
 from keras.src import initializers
 from keras.src import ops
 from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
             self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode != "int8":
-            # If the layer is quantized to int8, `self._kernel` will be added
-            # in `self._int8_build`. Therefore, we skip it here.
+        if self.quantization_mode not in ("int8", "int4"):
+            # If the layer is quantized to int8 or int4, `self._kernel` will be
+            # added in `self._int8_build` or `_int4_build`. Therefore, we skip
+            # it here.
             self._kernel = self.add_weight(
                 name="kernel",
                 shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+        # Determine the correct input dimension for the LoRA A matrix. When
+        # the layer has been int4-quantized, `self._kernel` stores a *packed*
+        # representation whose first dimension is `ceil(input_dim/2)`. We
+        # saved the true, *unpacked* input dimension in `self._orig_input_dim`
+        # during quantization. Use it if available; otherwise fall back to the
+        # first dimension of `self.kernel`.
+        if self.quantization_mode == "int4" and hasattr(
+            self, "_orig_input_dim"
+        ):
+            input_dim_for_lora = self._orig_input_dim
+        else:
+            input_dim_for_lora = self.kernel.shape[0]
+
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
-            shape=(self.kernel.shape[0], rank),
+            shape=(input_dim_for_lora, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
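For intuition about the shape handling in `enable_lora` above: with int4 packing, the stored kernel's first dimension shrinks to `ceil(input_dim/2)`, but the LoRA A matrix must still match the unpacked input dimension. A minimal standalone sketch with made-up numbers (not part of the diff):

```python
import math

# Hypothetical int4-quantized Dense layer with an odd input dimension.
input_dim, units, rank = 5, 16, 4

packed_rows = math.ceil(input_dim / 2)  # 3 rows actually stored in `_kernel`
lora_a_shape = (input_dim, rank)        # (5, 4), NOT (3, 4)
lora_b_shape = (rank, units)            # (4, 16)

print(packed_rows, lora_a_shape, lora_b_shape)
```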
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(self.kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
     def quantized_build(self, kernel_shape, mode):
         if mode == "int8":
             self._int8_build(kernel_shape)
+        elif mode == "int4":
+            self._int4_build(kernel_shape)
         elif mode == "float8":
             self._float8_build()
         else:
@@ -337,6 +352,38 @@ def _int8_build(self, kernel_shape):
             trainable=False,
         )
 
+    def _int4_build(self, kernel_shape):
+        """Build variables for int4 quantization.
+        `kernel_shape` is the *original* float32 kernel shape
+        `(input_dim, units)`. We allocate the stored kernel with rows
+        `ceil(input_dim/2)` because two int4 values are packed into a single
+        int8 byte.
+        """
+        # Per-channel quantizer for the last axis (features).
+        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
+            axis=-1, value_range=(-8, 7)
+        )
+        input_dim, output_dim = kernel_shape
+        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
+
+        # Kernel is stored **packed**: each int8 byte contains two int4 values.
+        self._kernel = self.add_weight(
+            name="kernel",
+            shape=(packed_rows, output_dim),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        # One scale per output unit (per-channel).
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record original input_dim for unpacking at runtime.
+        self._orig_input_dim = input_dim
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 
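To make the storage layout allocated by `_int4_build` concrete, here is an illustrative sketch with made-up shapes (plain NumPy, not part of the diff):

```python
import numpy as np

# Hypothetical Dense(64) layer with input_dim=128, quantized to int4.
input_dim, units = 128, 64
packed_rows = (input_dim + 1) // 2  # 64: two int4 values per int8 byte

packed_kernel = np.zeros((packed_rows, units), dtype="int8")
kernel_scale = np.ones((units,), dtype="float32")  # one scale per output unit

# A float32 kernel would occupy input_dim * units * 4 = 32768 bytes;
# the packed int4 kernel occupies packed_rows * units = 4096 bytes.
print(packed_kernel.nbytes, kernel_scale.nbytes)
```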
@@ -415,6 +462,49 @@ def grad_fn(*args, upstream=None):
             x = self.activation(x)
         return x
 
+    def _int4_call(self, inputs, training=None):
+        """Forward pass for int4 quantized Dense layer."""
+
+        @ops.custom_gradient
+        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            unpacked_kernel = self._unpack_int4_ops(
+                kernel, self._orig_input_dim
+            )
+
+            def grad_fn(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                    kernel_scale,
+                )
+                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
+                return (inputs_grad, None, None)
+
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            x = ops.matmul(inputs, unpacked_kernel)
+            x = ops.cast(x, self.compute_dtype)
+            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            return x, grad_fn
+
+        x = matmul_with_inputs_gradient(
+            inputs,
+            ops.convert_to_tensor(self._kernel),
+            ops.convert_to_tensor(self.kernel_scale),
+        )
+
+        if self.lora_enabled:
+            lora_x = ops.matmul(inputs, self.lora_kernel_a)
+            lora_x = ops.matmul(lora_x, self.lora_kernel_b)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+
+        # Add bias and activation
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
+        return x
+
     def _float8_call(self, inputs, training=None):
         if self.lora_enabled:
             raise NotImplementedError(
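The de-scaling in `_int4_call` relies on the usual absmax identity: if `x_q ≈ x * inputs_scale` and `w_q ≈ W * kernel_scale`, then `x @ W ≈ (x_q @ w_q) / (inputs_scale * kernel_scale)`. A rough NumPy sketch of that arithmetic (illustrative only, not the Keras implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8)).astype("float32")  # activations
w = rng.normal(size=(8, 4)).astype("float32")  # float kernel

# Absmax scales: per-row for the inputs, per-column for the kernel.
# The diff's `_int4_build` uses the int4 range (-8, 7) for inputs as well.
x_scale = 7.0 / np.max(np.abs(x), axis=-1, keepdims=True)
w_scale = 7.0 / np.max(np.abs(w), axis=0, keepdims=True)
x_q = np.clip(np.round(x * x_scale), -8, 7)
w_q = np.clip(np.round(w * w_scale), -8, 7)

# Integer-domain matmul, then undo both scales.
y_approx = (x_q @ w_q) / (x_scale * w_scale)
print(np.max(np.abs(y_approx - x @ w)))  # quantization error vs. float matmul
```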
@@ -518,13 +608,42 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
-        if mode == "int8":
+            # Build variables for int8 mode
+            self.quantized_build(kernel_shape, mode)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "int4":
+            # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
+            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                self._kernel,
+                axis=0,
+                value_range=(-8, 7),
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_kernel_value, packed_shape, orig_rows = self._pack_int4_ops(
+                kernel_value_int4
+            )
+            del self._kernel
+            # Save original input dim for unpacking.
+            self._orig_input_dim = orig_rows
+            # Build variables using the original kernel shape; _int4_build will
+            # compute the packed shape internally.
+            self.quantized_build(kernel_shape, mode)
+            # Assign packed values.
+            self._kernel.assign(packed_kernel_value)
+            self.kernel_scale.assign(kernel_scale)
+        elif mode == "float8":
+            self.quantized_build(kernel_shape, mode)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Set new dtype policy only for modes that already have a policy.
         if self.dtype_policy.quantization_mode is None:
+            from keras.src import dtype_policies  # local import to avoid cycle
+
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
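Assuming the API surface this diff adds (an `"int4"` mode string alongside the existing `"int8"` and `"float8"`), end-to-end usage would look like the hedged sketch below; it is not taken from the PR's tests, and the final comparison is only a sanity check.

```python
import numpy as np
import keras

layer = keras.layers.Dense(16, activation="relu")
layer.build((None, 32))

x = np.random.rand(4, 32).astype("float32")
y_float = keras.ops.convert_to_numpy(layer(x))

# Quantize the built layer in place. With this diff, "int4" packs the kernel
# two values per int8 byte and keeps one float scale per output unit.
layer.quantize("int4")
y_int4 = keras.ops.convert_to_numpy(layer(x))

print(np.max(np.abs(y_int4 - y_float)))  # quantization error should be modest
```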
@@ -533,17 +652,104 @@ def _get_kernel_with_merged_lora(self):
             kernel_value = self._kernel
             kernel_scale = self.kernel_scale
             if self.lora_enabled:
-                # Dequantize & quantize to merge lora weights into int8 kernel
-                # Note that this is a lossy compression
-                kernel_value = ops.divide(kernel_value, kernel_scale)
-                kernel_value = ops.add(
-                    kernel_value,
-                    (self.lora_alpha / self.lora_rank)
-                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
-                )
-                kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=0, to_numpy=True
-                )
-                kernel_scale = ops.squeeze(kernel_scale, axis=0)
+                # For int4, `_kernel` is stored in a packed representation
+                # (two int4 values per int8 byte). We need to unpack it to the
+                # original float representation before merging with the LoRA
+                # update, and then pack it again after requantization.
+                if self.quantization_mode == "int4":
+                    # 1) Unpack packed int4 tensor to int8 range [-8, 7].
+                    unpacked_kernel = self._unpack_int4_ops(
+                        kernel_value, self._orig_input_dim
+                    )
+                    # 2) De-scale to recover float32 kernel.
+                    kernel_value_fp = ops.divide(unpacked_kernel, kernel_scale)
+                    # 3) Merge LoRA delta in float32 domain.
+                    kernel_value_fp = ops.add(
+                        kernel_value_fp,
+                        (self.lora_alpha / self.lora_rank)
+                        * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                    )
+                    # 4) Re-quantize to int4 (values still held in int8 dtype).
+                    kernel_int4, kernel_scale = quantizers.abs_max_quantize(
+                        kernel_value_fp,
+                        axis=0,
+                        value_range=(-8, 7),
+                        dtype="int8",
+                        to_numpy=True,
+                    )
+                    kernel_scale = ops.squeeze(kernel_scale, axis=0)
+                    # 5) Pack the int4 values back into the compact int8 layout.
+                    kernel_value, _, _ = self._pack_int4_ops(kernel_int4)
+                else:
+                    # int8 path (regular): unpacking not required.
+                    kernel_value = ops.divide(kernel_value, kernel_scale)
+                    kernel_value = ops.add(
+                        kernel_value,
+                        (self.lora_alpha / self.lora_rank)
+                        * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                    )
+                    kernel_value, kernel_scale = quantizers.abs_max_quantize(
+                        kernel_value, axis=0, to_numpy=True
+                    )
+                    kernel_scale = ops.squeeze(kernel_scale, axis=0)
             return kernel_value, kernel_scale
         return self.kernel, None
+
+    def _pack_int4_ops(self, arr):
+        """Pack an int4 tensor into an int8 tensor with packed nibbles.
+
+        Accepts a Keras-compatible tensor. The input values must already be
+        int8 in the signed range ``[-8, 7]`` and represent the desired int4
+        values. Packing is performed along axis 0:
+
+        * For every two consecutive rows, the **low nibble** of the output byte
+          stores the value from the first row, and the **high nibble** stores
+          the value from the second row.
+
+        Returns a tuple ``(packed, packed_shape, orig_rows)`` where ``packed``
+        is the packed ``int8`` tensor, ``packed_shape`` is its shape, and
+        ``orig_rows`` is the original (unpacked) row count prior to any padding
+        that may have been inserted when an odd number of rows is supplied.
+        """
+        if arr.dtype != "int8":
+            raise TypeError("Expected int8 tensor for packing")
+
+        shape = ops.shape(arr)
+        rows, cols = shape[0], shape[1]
+
+        orig_rows = rows
+        if rows % 2 == 1:
+            padding_row = ops.zeros((1, cols), dtype="int8")
+            arr = ops.concatenate([arr, padding_row], axis=0)
+            rows += 1
+
+        # Map signed [-8,7] to unsigned 4-bit two's complement (0..15)
+        arr_u = ops.where(arr < 0, arr + 16, arr)
+        arr_u = ops.cast(arr_u, "uint8")
+        arr_u = ops.reshape(arr_u, (rows // 2, 2, cols))
+        low = arr_u[:, 0, :]
+        high = arr_u[:, 1, :]
+        packed = ops.bitwise_or(ops.left_shift(high, 4), low)
+        packed = ops.cast(packed, "int8")
+        return packed, ops.shape(packed), orig_rows
+
+    @staticmethod
+    def _unpack_int4_ops(packed, orig_rows):
+        """Unpack packed int4 tensor (ops) to int8 [-8,7]."""
+        # Bitwise operations work element-wise.
+        low = ops.bitwise_and(packed, 0x0F)
+        high = ops.right_shift(packed, 4)
+        high = ops.bitwise_and(high, 0x0F)
+
+        def _to_signed(x):
+            return ops.where(x < 8, x, ops.subtract(x, 16))
+
+        low = _to_signed(low)
+        high = _to_signed(high)
+
+        # Interleave rows back: stacked shape (packed_rows, 2, cols)
+        stacked = ops.stack([low, high], axis=1)  # (pairs, 2, cols)
+        unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1]))
+        # Remove potential padded row.
+        unpacked = unpacked_full[:orig_rows, :]
+        return unpacked
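A standalone NumPy round trip of the nibble-packing scheme above (illustrative only; the helper names `pack_int4`/`unpack_int4` are hypothetical and simply mirror `_pack_int4_ops`/`_unpack_int4_ops`):

```python
import numpy as np


def pack_int4(arr):
    """Pack int8 values in [-8, 7] pairwise along axis 0, low nibble first."""
    rows, cols = arr.shape
    if rows % 2 == 1:  # pad an extra zero row for an odd input_dim
        arr = np.concatenate([arr, np.zeros((1, cols), dtype="int8")], axis=0)
    u = np.where(arr < 0, arr + 16, arr).astype("uint8")  # two's complement
    u = u.reshape(-1, 2, cols)
    packed = ((u[:, 1, :] << 4) | u[:, 0, :]).astype("int8")
    return packed, rows


def unpack_int4(packed, orig_rows):
    """Invert pack_int4, dropping any padded row."""
    p = packed.astype("uint8")  # reinterpret the stored bytes as unsigned
    low = (p & 0x0F).astype("int16")
    high = ((p >> 4) & 0x0F).astype("int16")

    def to_signed(x):  # 4-bit two's complement nibble -> [-8, 7]
        return np.where(x < 8, x, x - 16).astype("int8")

    stacked = np.stack([to_signed(low), to_signed(high)], axis=1)
    return stacked.reshape(-1, packed.shape[-1])[:orig_rows, :]


values = np.array([[-8, 7], [3, -1], [5, 0]], dtype="int8")  # odd row count
packed, orig_rows = pack_int4(values)
assert packed.shape == (2, 2)  # ceil(3 / 2) packed rows
assert np.array_equal(unpack_int4(packed, orig_rows), values)
```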