Commit c02e302

int4 quantization support
1 parent 744b8be commit c02e302

4 files changed (+402, −27 lines)

keras/src/dtype_policies/dtype_policy.py
Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8")
+QUANTIZATION_MODES = ("int8", "float8", "int4")
 
 
 @keras_export(
@@ -350,7 +350,7 @@ def _get_quantized_dtype_policy_by_str(policy):
             f"Received: policy={policy}"
         )
     mode, source_name = split_name
-    if policy.startswith("int8"):
+    if policy.startswith("int8") or policy.startswith("int4"):
        return QuantizedDTypePolicy(mode, source_name)
    elif policy.startswith("float8"):
        return QuantizedFloat8DTypePolicy(mode, source_name)
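
With "int4" added to QUANTIZATION_MODES, the string form of a quantized policy follows the same `<mode>_from_<source>` pattern already used for int8. A minimal sketch of how the new mode would be resolved (illustrative only; it assumes a float32 source policy):

    from keras.src import dtype_policies

    # "int4_from_float32" is split into mode="int4" and source_name="float32"
    # by _get_quantized_dtype_policy_by_str, yielding a QuantizedDTypePolicy.
    policy = dtype_policies.get("int4_from_float32")
    print(policy.quantization_mode)  # "int4"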

keras/src/layers/core/dense.py
Lines changed: 228 additions & 22 deletions

@@ -2,7 +2,6 @@
 
 from keras.src import activations
 from keras.src import constraints
-from keras.src import dtype_policies
 from keras.src import initializers
 from keras.src import ops
 from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
             self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode != "int8":
-            # If the layer is quantized to int8, `self._kernel` will be added
-            # in `self._int8_build`. Therefore, we skip it here.
+        if self.quantization_mode not in ("int8", "int4"):
+            # If the layer is quantized to int8 or int4, `self._kernel` will be
+            # added in `self._int8_build` or `_int4_build`. Therefore, we skip
+            # it here.
             self._kernel = self.add_weight(
                 name="kernel",
                 shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+        # Determine the correct input dimension for the LoRA A matrix. When
+        # the layer has been int4-quantized, `self._kernel` stores a *packed*
+        # representation whose first dimension is `ceil(input_dim/2)`. We
+        # saved the true, *unpacked* input dimension in `self._orig_input_dim`
+        # during quantization. Use it if available; otherwise fall back to the
+        # first dimension of `self.kernel`.
+        if self.quantization_mode == "int4" and hasattr(
+            self, "_orig_input_dim"
+        ):
+            input_dim_for_lora = self._orig_input_dim
+        else:
+            input_dim_for_lora = self.kernel.shape[0]
+
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
-            shape=(self.kernel.shape[0], rank),
+            shape=(input_dim_for_lora, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
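
The intent of this hunk is that LoRA can still be enabled after int4 quantization: the A matrix must use the original (unpacked) input dimension, not the first dimension of the packed kernel. A usage sketch under that assumption (layer sizes are illustrative, and it presumes the rest of the int4 path in this commit is in place):

    import keras

    layer = keras.layers.Dense(units=16)
    layer.build((None, 7))      # float kernel shape (7, 16)
    layer.quantize("int4")      # packed int8 kernel shape (4, 16)
    layer.enable_lora(rank=2)   # lora_kernel_a gets shape (7, 2), not (4, 2)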
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(self.kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
     def quantized_build(self, kernel_shape, mode):
         if mode == "int8":
             self._int8_build(kernel_shape)
+        elif mode == "int4":
+            self._int4_build(kernel_shape)
         elif mode == "float8":
             self._float8_build()
         else:
@@ -337,6 +352,38 @@ def _int8_build(self, kernel_shape):
             trainable=False,
         )
 
+    def _int4_build(self, kernel_shape):
+        """Build variables for int4 quantization.
+        `kernel_shape` is the *original* float32 kernel shape
+        `(input_dim, units)`. We allocate the stored kernel with rows
+        `ceil(input_dim/2)` because two int4 values are packed into a single
+        int8 byte.
+        """
+        # Per-channel quantizer for the last axis (features).
+        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
+            axis=-1, value_range=(-8, 7)
+        )
+        input_dim, output_dim = kernel_shape
+        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
+
+        # Kernel is stored **packed**: each int8 byte contains two int4 values.
+        self._kernel = self.add_weight(
+            name="kernel",
+            shape=(packed_rows, output_dim),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        # One scale per output unit (per-channel).
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record original input_dim for unpacking at runtime.
+        self._orig_input_dim = input_dim
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 
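The packed layout roughly halves the kernel storage relative to int8. A quick sketch of the shape arithmetic used by `_int4_build` (illustrative values only):

    # Shape arithmetic mirrored from _int4_build.
    input_dim, units = 7, 16
    packed_rows = (input_dim + 1) // 2   # ceil(7 / 2) = 4

    # Variables created when building in int4 mode:
    #   _kernel      -> shape (4, 16), dtype int8, two int4 values per byte
    #   kernel_scale -> shape (16,), one scale per output unit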
@@ -415,6 +462,49 @@ def grad_fn(*args, upstream=None):
             x = self.activation(x)
         return x
 
+    def _int4_call(self, inputs, training=None):
+        """Forward pass for int4 quantized Dense layer."""
+
+        @ops.custom_gradient
+        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            unpacked_kernel = self._unpack_int4_ops(
+                kernel, self._orig_input_dim
+            )
+
+            def grad_fn(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                    kernel_scale,
+                )
+                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
+                return (inputs_grad, None, None)
+
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            x = ops.matmul(inputs, unpacked_kernel)
+            x = ops.cast(x, self.compute_dtype)
+            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            return x, grad_fn
+
+        x = matmul_with_inputs_gradient(
+            inputs,
+            ops.convert_to_tensor(self._kernel),
+            ops.convert_to_tensor(self.kernel_scale),
+        )
+
+        if self.lora_enabled:
+            lora_x = ops.matmul(inputs, self.lora_kernel_a)
+            lora_x = ops.matmul(lora_x, self.lora_kernel_b)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+
+        # Add bias and activation
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
+        return x
+
     def _float8_call(self, inputs, training=None):
         if self.lora_enabled:
             raise NotImplementedError(
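
The int4 forward pass mirrors the int8 one: quantized inputs and the unpacked quantized kernel are multiplied, then the result is rescaled by the product of the two scales. A rough NumPy sketch of that arithmetic (not the backend implementation; it assumes the abs-max convention scale = 7 / abs_max implied by value_range=(-8, 7)):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 7)).astype("float32")   # inputs
    w = rng.normal(size=(7, 16)).astype("float32")  # unpacked kernel

    # Abs-max scales: per input row, per output column.
    x_scale = 7.0 / np.max(np.abs(x), axis=-1, keepdims=True)
    w_scale = 7.0 / np.max(np.abs(w), axis=0, keepdims=True)
    x_q = np.round(x * x_scale).clip(-8, 7)
    w_q = np.round(w * w_scale).clip(-8, 7)

    # Integer-valued matmul, then undo both scales, as in
    # ops.divide(x, ops.multiply(inputs_scale, kernel_scale)).
    y_approx = (x_q @ w_q) / (x_scale * w_scale)
    print(np.max(np.abs(y_approx - x @ w)))  # small quantization error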
@@ -518,13 +608,42 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
-        if mode == "int8":
+            # Build variables for int8 mode
+            self.quantized_build(kernel_shape, mode)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "int4":
+            # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
+            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                self._kernel,
+                axis=0,
+                value_range=(-8, 7),
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_kernel_value, packed_shape, orig_rows = self._pack_int4_ops(
+                kernel_value_int4
+            )
+            del self._kernel
+            # Save original input dim for unpacking.
+            self._orig_input_dim = orig_rows
+            # Build variables using the original kernel shape; _int4_build will
+            # compute the packed shape internally.
+            self.quantized_build(kernel_shape, mode)
+            # Assign packed values.
+            self._kernel.assign(packed_kernel_value)
+            self.kernel_scale.assign(kernel_scale)
+        elif mode == "float8":
+            self.quantized_build(kernel_shape, mode)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Set new dtype policy only for modes that already have a policy.
         if self.dtype_policy.quantization_mode is None:
+            from keras.src import dtype_policies  # local import to avoid cycle
+
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
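End to end, the new mode is exercised through the same `quantize()` entry point as int8 and float8. A sketch of the intended call path (it assumes the dispatch and quantizer changes in the two files of this commit that are not shown above):

    import numpy as np
    import keras

    layer = keras.layers.Dense(units=8, activation="relu")
    layer.build((None, 5))
    layer.quantize("int4")

    print(layer.dtype_policy.name)    # "int4_from_float32" for a float32 layer
    print(layer.kernel_scale.shape)   # (8,)

    y = layer(np.random.rand(2, 5).astype("float32"))  # int4 forward path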
@@ -533,17 +652,104 @@ def _get_kernel_with_merged_lora(self):
             kernel_value = self._kernel
             kernel_scale = self.kernel_scale
             if self.lora_enabled:
-                # Dequantize & quantize to merge lora weights into int8 kernel
-                # Note that this is a lossy compression
-                kernel_value = ops.divide(kernel_value, kernel_scale)
-                kernel_value = ops.add(
-                    kernel_value,
-                    (self.lora_alpha / self.lora_rank)
-                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
-                )
-                kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=0, to_numpy=True
-                )
-                kernel_scale = ops.squeeze(kernel_scale, axis=0)
+                # For int4, `_kernel` is stored in a packed representation
+                # (two int4 values per int8 byte). We need to unpack it to the
+                # original float representation before merging with the LoRA
+                # update, and then pack it again after requantization.
+                if self.quantization_mode == "int4":
+                    # 1) Unpack packed int4 tensor to int8 range [-8, 7].
+                    unpacked_kernel = self._unpack_int4_ops(
+                        kernel_value, self._orig_input_dim
+                    )
+                    # 2) De-scale to recover float32 kernel.
+                    kernel_value_fp = ops.divide(unpacked_kernel, kernel_scale)
+                    # 3) Merge LoRA delta in float32 domain.
+                    kernel_value_fp = ops.add(
+                        kernel_value_fp,
+                        (self.lora_alpha / self.lora_rank)
+                        * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                    )
+                    # 4) Re-quantize to int4 (values still held in int8 dtype).
+                    kernel_int4, kernel_scale = quantizers.abs_max_quantize(
+                        kernel_value_fp,
+                        axis=0,
+                        value_range=(-8, 7),
+                        dtype="int8",
+                        to_numpy=True,
+                    )
+                    kernel_scale = ops.squeeze(kernel_scale, axis=0)
+                    # 5) Pack the int4 values back into the compact int8 layout.
+                    kernel_value, _, _ = self._pack_int4_ops(kernel_int4)
+                else:
+                    # int8 path (regular): unpacking not required.
+                    kernel_value = ops.divide(kernel_value, kernel_scale)
+                    kernel_value = ops.add(
+                        kernel_value,
+                        (self.lora_alpha / self.lora_rank)
+                        * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                    )
+                    kernel_value, kernel_scale = quantizers.abs_max_quantize(
+                        kernel_value, axis=0, to_numpy=True
+                    )
+                    kernel_scale = ops.squeeze(kernel_scale, axis=0)
             return kernel_value, kernel_scale
         return self.kernel, None
+
+    def _pack_int4_ops(self, arr):
+        """Pack an int4 tensor into an int8 tensor with packed nibbles.
+
+        Accepts a Keras-compatible tensor. The input values must already be int8
+        in the signed range ``[-8, 7]`` and represent the desired int4 values.
+        Packing is performed along axis 0:
+
+        * For every two consecutive rows, the **low nibble** of the output byte
+          stores the value from the first row, and the **high nibble** stores
+          the value from the second row.
+
+        Returns a tuple ``(packed, packed_shape, orig_rows)`` where ``packed``
+        is the packed ``int8`` tensor, ``packed_shape`` is its shape, and
+        ``orig_rows`` is the original (unpacked) row count prior to any padding
+        that may have been inserted when an odd number of rows is supplied.
+        """
+        if arr.dtype != "int8":
+            raise TypeError("Expected int8 tensor for packing")
+
+        shape = ops.shape(arr)
+        rows, cols = shape[0], shape[1]
+
+        orig_rows = rows
+        if rows % 2 == 1:
+            padding_row = ops.zeros((1, cols), dtype="int8")
+            arr = ops.concatenate([arr, padding_row], axis=0)
+            rows += 1
+
+        # Map signed [-8, 7] to unsigned 4-bit two's complement (0..15).
+        arr_u = ops.where(arr < 0, arr + 16, arr)
+        arr_u = ops.cast(arr_u, "uint8")
+        arr_u = ops.reshape(arr_u, (rows // 2, 2, cols))
+        low = arr_u[:, 0, :]
+        high = arr_u[:, 1, :]
+        packed = ops.bitwise_or(ops.left_shift(high, 4), low)
+        packed = ops.cast(packed, "int8")
+        return packed, ops.shape(packed), orig_rows
+
+    @staticmethod
+    def _unpack_int4_ops(packed, orig_rows):
+        """Unpack a packed int4 tensor back to int8 values in [-8, 7]."""
+        # Bitwise operations work element-wise.
+        low = ops.bitwise_and(packed, 0x0F)
+        high = ops.right_shift(packed, 4)
+        high = ops.bitwise_and(high, 0x0F)
+
+        def _to_signed(x):
+            return ops.where(x < 8, x, ops.subtract(x, 16))
+
+        low = _to_signed(low)
+        high = _to_signed(high)
+
+        # Interleave rows back: stacked shape (packed_rows, 2, cols).
+        stacked = ops.stack([low, high], axis=1)  # (pairs, 2, cols)
+        unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1]))
+        # Remove potential padded row.
+        unpacked = unpacked_full[:orig_rows, :]
+        return unpacked
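
The packing scheme is easiest to see in plain NumPy: map each signed value in [-8, 7] to an unsigned nibble via two's complement, put even rows in the low nibble and odd rows in the high nibble, and reverse both steps to unpack. A standalone sketch of the same round trip (NumPy stand-ins, not the keras.ops implementation above):

    import numpy as np

    def pack_int4(arr):
        # arr: int8 array with values in [-8, 7]; pack along axis 0.
        rows, cols = arr.shape
        if rows % 2:  # pad to an even number of rows
            arr = np.concatenate([arr, np.zeros((1, cols), dtype=np.int8)])
        u = np.where(arr < 0, arr + 16, arr).astype(np.uint8)  # unsigned nibbles
        low, high = u[0::2], u[1::2]  # even rows -> low nibble, odd rows -> high
        packed = ((high << 4) | low).astype(np.int8)  # reinterpret byte as int8
        return packed, rows

    def unpack_int4(packed, orig_rows):
        p = packed.astype(np.uint8).astype(np.int16)  # widen so nibble math is exact
        low, high = p & 0x0F, (p >> 4) & 0x0F
        to_signed = lambda x: np.where(x < 8, x, x - 16).astype(np.int8)
        out = np.empty((2 * p.shape[0], p.shape[1]), dtype=np.int8)
        out[0::2], out[1::2] = to_signed(low), to_signed(high)
        return out[:orig_rows]

    w = np.array([[-8, 7], [3, -1], [5, 0]], dtype=np.int8)  # odd row count
    packed, orig_rows = pack_int4(w)
    assert (unpack_int4(packed, orig_rows) == w).all()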
