
Add int4 Quantization Support #21435


Open · wants to merge 12 commits into base: master
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -19,6 +19,8 @@
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
)
from keras.src.quantizers.quantizers import pack_int4 as pack_int4
from keras.src.quantizers.quantizers import (
quantize_and_dequantize as quantize_and_dequantize,
)
from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4
2 changes: 2 additions & 0 deletions keras/api/quantizers/__init__.py
@@ -19,6 +19,8 @@
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
)
from keras.src.quantizers.quantizers import pack_int4 as pack_int4
from keras.src.quantizers.quantizers import (
quantize_and_dequantize as quantize_and_dequantize,
)
from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4
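
These exports surface the new packing helpers as `keras.quantizers.pack_int4` and `keras.quantizers.unpack_int4`. A minimal usage sketch, with signatures inferred from how `dense.py` calls them later in this diff (the second return value of `pack_int4` is unused there, so it is left unnamed here):

from keras import ops, quantizers

# An int8 tensor holding int4 values in the range [-8, 7]; the packing axis
# is assumed to be the first dimension, matching the Dense kernel usage below.
values = ops.convert_to_tensor([[-8, 7, 3], [1, -2, 5]], dtype="int8")
packed, _, orig_len = quantizers.pack_int4(values)
unpacked = quantizers.unpack_int4(packed, orig_len)  # recovers `values`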
4 changes: 2 additions & 2 deletions keras/src/dtype_policies/dtype_policy.py
@@ -3,7 +3,7 @@
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state

QUANTIZATION_MODES = ("int8", "float8")
QUANTIZATION_MODES = ("int8", "float8", "int4")


@keras_export(
@@ -350,7 +350,7 @@ def _get_quantized_dtype_policy_by_str(policy):
f"Received: policy={policy}"
)
mode, source_name = split_name
if policy.startswith("int8"):
if policy.startswith("int8") or policy.startswith("int4"):
return QuantizedDTypePolicy(mode, source_name)
elif policy.startswith("float8"):
return QuantizedFloat8DTypePolicy(mode, source_name)
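
With "int4" added to QUANTIZATION_MODES, a policy string such as "int4_from_float32" resolves to a plain QuantizedDTypePolicy, exactly like the int8 path. A small illustrative sketch of the resolution above:

from keras.src import dtype_policies

policy = dtype_policies.get("int4_from_float32")
# Equivalent to QuantizedDTypePolicy("int4", "float32").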
209 changes: 187 additions & 22 deletions keras/src/layers/core/dense.py
@@ -2,7 +2,6 @@

from keras.src import activations
from keras.src import constraints
from keras.src import dtype_policies
from keras.src import initializers
from keras.src import ops
from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
kernel_shape = (input_shape[-1], self.units)
if self.quantization_mode:
self.quantized_build(kernel_shape, mode=self.quantization_mode)
if self.quantization_mode != "int8":
# If the layer is quantized to int8, `self._kernel` will be added
# in `self._int8_build`. Therefore, we skip it here.
if self.quantization_mode not in ("int8", "int4"):
# If the layer is quantized to int8 or int4, `self._kernel` will be
# added in `self._int8_build` or `_int4_build`. Therefore, we skip
# it here.
self._kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
"lora is already enabled. This can only be done once per layer."
)
self._tracker.unlock()
# Determine the correct input dimension for the LoRA A matrix. When
# the layer has been int4-quantized, `self._kernel` stores a *packed*
# representation whose first dimension is `ceil(input_dim/2)`. We
# saved the true, *unpacked* input dimension in `self._orig_input_dim`
# during quantization. Use it if available; otherwise fall back to the
# first dimension of `self.kernel`.
if self.quantization_mode == "int4" and hasattr(
self, "_orig_input_dim"
):
input_dim_for_lora = self._orig_input_dim
else:
input_dim_for_lora = self.kernel.shape[0]

self.lora_kernel_a = self.add_weight(
name="lora_kernel_a",
shape=(self.kernel.shape[0], rank),
shape=(input_dim_for_lora, rank),
initializer=initializers.get(a_initializer),
regularizer=self.kernel_regularizer,
)
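
A quick shape check of why the unpacked dimension matters here (illustrative numbers, not taken from the PR):

input_dim, units, rank = 5, 16, 4
packed_rows = (input_dim + 1) // 2    # the int4 kernel is stored with 3 rows
lora_a_shape = (input_dim, rank)      # LoRA A still needs (5, 4), hence `_orig_input_dim`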
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
if self.use_bias:
target_variables.append(self.bias)
if self.quantization_mode is not None:
if self.quantization_mode == "int8":
if self.quantization_mode in ("int8", "int4"):
target_variables.append(kernel_scale)
elif self.quantization_mode == "float8":
target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
if self.use_bias:
target_variables.append(self.bias)
if self.quantization_mode is not None:
if self.quantization_mode == "int8":
if self.quantization_mode in ("int8", "int4"):
target_variables.append(self.kernel_scale)
elif self.quantization_mode == "float8":
target_variables.append(self.inputs_scale)
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
def quantized_build(self, kernel_shape, mode):
if mode == "int8":
self._int8_build(kernel_shape)
elif mode == "int4":
self._int4_build(kernel_shape)
elif mode == "float8":
self._float8_build()
else:
@@ -337,6 +352,38 @@ def _int8_build(self, kernel_shape):
trainable=False,
)

def _int4_build(self, kernel_shape):
"""Build variables for int4 quantization.

`kernel_shape` is the *original* float32 kernel shape
`(input_dim, units)`. The stored kernel is allocated with
`ceil(input_dim/2)` rows because two int4 values are packed into a
single int8 byte.
"""
# Inputs are still quantized per-channel to int8 over the last axis
# (features); only the kernel uses int4.
self.inputs_quantizer = quantizers.AbsMaxQuantizer(
axis=-1,
)
input_dim, output_dim = kernel_shape
packed_rows = (input_dim + 1) // 2 # ceil for odd dims

# Kernel is stored **packed**: each int8 byte contains two int4 values.
self._kernel = self.add_weight(
name="kernel",
shape=(packed_rows, output_dim),
initializer="zeros",
dtype="int8",
trainable=False,
)
# One scale per output unit (per-channel).
self.kernel_scale = self.add_weight(
name="kernel_scale",
shape=(self.units,),
initializer="ones",
trainable=False,
)
# Record original input_dim for unpacking at runtime.
self._orig_input_dim = input_dim
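
For reference, a standalone NumPy sketch of the nibble packing described above (illustrative only, not this PR's `pack_int4`/`unpack_int4` implementation; the nibble order is an assumption). It shows why the stored kernel needs only `ceil(input_dim/2)` rows:

import numpy as np

def pack_rows_int4(values):
    # `values`: int8 array of shape (rows, cols) holding int4 values in [-8, 7].
    rows, cols = values.shape
    if rows % 2:
        # Pad odd row counts so every output byte holds exactly two values.
        values = np.concatenate([values, np.zeros((1, cols), dtype=values.dtype)])
    low = values[0::2].astype(np.uint8) & 0x0F    # even rows -> low nibble
    high = values[1::2].astype(np.uint8) & 0x0F   # odd rows  -> high nibble
    packed = ((high << 4) | low).astype(np.int8)  # shape (ceil(rows / 2), cols)
    return packed, rows

def unpack_rows_int4(packed, orig_rows):
    u = packed.astype(np.uint8)
    low = (u & 0x0F).astype(np.int16)
    high = (u >> 4).astype(np.int16)
    # Sign-extend 4-bit two's-complement values back to [-8, 7].
    low = np.where(low > 7, low - 16, low).astype(np.int8)
    high = np.where(high > 7, high - 16, high).astype(np.int8)
    out = np.empty((2 * packed.shape[0], packed.shape[1]), dtype=np.int8)
    out[0::2], out[1::2] = low, high
    return out[:orig_rows]

w = np.array([[-8, 7], [3, -2], [5, 1]], dtype=np.int8)   # 3 rows -> 2 packed rows
packed, rows = pack_rows_int4(w)
assert (unpack_rows_int4(packed, rows) == w).all()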

def _float8_build(self):
from keras.src.dtype_policies import QuantizedFloat8DTypePolicy

@@ -383,6 +430,16 @@ def _float8_build(self):
def _int8_call(self, inputs, training=None):
@ops.custom_gradient
def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
"""Custom gradient function to handle the int8 quantized weights.

Automatic differentiation will not know how to handle the int8
quantized weights. So a custom gradient function is needed to
handle the int8 quantized weights.

The custom gradient function will use the dequantized kernel to
compute the gradient.
"""

def grad_fn(*args, upstream=None):
if upstream is None:
(upstream,) = args
@@ -415,6 +472,59 @@ def grad_fn(*args, upstream=None):
x = self.activation(x)
return x

def _int4_call(self, inputs, training=None):
"""Forward pass for int4 quantized Dense layer."""

@ops.custom_gradient
def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
"""Custom gradient function for int4 quantized weights.

Automatic differentiation will not know how to handle the
int4 quantized weights. So a custom gradient function is needed
to handle the int4 quantized weights.

The custom gradient function will use the dequantized kernel to
compute the gradient.
"""

unpacked_kernel = quantizers.unpack_int4(
kernel, self._orig_input_dim
)

def grad_fn(*args, upstream=None):
if upstream is None:
(upstream,) = args
float_kernel = ops.divide(
ops.cast(unpacked_kernel, dtype=self.compute_dtype),
kernel_scale,
)
inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
return (inputs_grad, None, None)

inputs, inputs_scale = self.inputs_quantizer(inputs)
x = ops.matmul(inputs, unpacked_kernel)
x = ops.cast(x, self.compute_dtype)
x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
return x, grad_fn

x = matmul_with_inputs_gradient(
inputs,
ops.convert_to_tensor(self._kernel),
ops.convert_to_tensor(self.kernel_scale),
)

if self.lora_enabled:
lora_x = ops.matmul(inputs, self.lora_kernel_a)
lora_x = ops.matmul(lora_x, self.lora_kernel_b)
x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)

# Add bias and activation
if self.bias is not None:
x = ops.add(x, self.bias)
if self.activation is not None:
x = self.activation(x)
return x
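
The rescaling above follows the usual abs-max identity: quantize inputs and (unpacked) kernel, matmul in the integer domain, then divide by the product of the two scales. A minimal NumPy sketch of that identity (scales and ranges are illustrative, not the exact `AbsMaxQuantizer` configuration):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4)).astype("float32")    # inputs
w = rng.standard_normal((4, 3)).astype("float32")    # dequantized kernel

s_x = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)  # per-row input scale (int8 range)
s_w = 7.0 / np.max(np.abs(w), axis=0, keepdims=True)     # per-column kernel scale (int4 range)
x_q = np.round(x * s_x)
w_q = np.round(w * s_w)

y = (x_q @ w_q) / (s_x * s_w)      # mirrors ops.divide(x, inputs_scale * kernel_scale)
print(np.max(np.abs(y - x @ w)))   # small quantization error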

def _float8_call(self, inputs, training=None):
if self.lora_enabled:
raise NotImplementedError(
@@ -518,13 +628,40 @@ def quantize(self, mode, type_check=True):
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
del self._kernel
self.quantized_build(kernel_shape, mode)
if mode == "int8":
# Build variables for int8 mode
self.quantized_build(kernel_shape, mode)
self._kernel.assign(kernel_value)
self.kernel_scale.assign(kernel_scale)
elif mode == "int4":
# 1. Quantize to int4 values (still int8 dtype, range [-8,7])
kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
self._kernel,
axis=0,
value_range=(-8, 7),
dtype="int8",
to_numpy=True,
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
# 2. Pack two int4 values into a single int8 byte.
packed_kernel_value, _, orig_rows = quantizers.pack_int4(
kernel_value_int4
)
del self._kernel
# Build variables using the original kernel shape; _int4_build will
# compute the packed shape internally.
self.quantized_build(kernel_shape, mode)
# Assign packed values.
self._kernel.assign(packed_kernel_value)
self.kernel_scale.assign(kernel_scale)
elif mode == "float8":
self.quantized_build(kernel_shape, mode)
else:
raise self._quantization_mode_error(mode)

# Set new dtype policy
# Set the new dtype policy only if the layer's policy is not already
# a quantized one.
if self.dtype_policy.quantization_mode is None:
from keras.src import dtype_policies # local import to avoid cycle

policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
self.dtype_policy = policy
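
Taken together, the int4 path can be exercised roughly as follows (a hypothetical usage sketch assuming this PR; the shapes follow from `_int4_build`):

import keras

layer = keras.layers.Dense(8)
layer.build((None, 16))
layer.quantize("int4")

# The kernel is packed along the input dimension: ceil(16 / 2) = 8 rows.
print(layer._kernel.shape)       # (8, 8), dtype int8
print(layer.kernel_scale.shape)  # (8,), one scale per output unit
print(layer.dtype_policy.name)   # expected "int4_from_float32"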

@@ -533,17 +670,45 @@ def _get_kernel_with_merged_lora(self):
kernel_value = self._kernel
kernel_scale = self.kernel_scale
if self.lora_enabled:
# Dequantize & quantize to merge lora weights into int8 kernel
# Note that this is a lossy compression
kernel_value = ops.divide(kernel_value, kernel_scale)
kernel_value = ops.add(
kernel_value,
(self.lora_alpha / self.lora_rank)
* ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
)
kernel_value, kernel_scale = quantizers.abs_max_quantize(
kernel_value, axis=0, to_numpy=True
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
# For int4, `_kernel` is stored in a packed representation
# (two int4 values per int8 byte). We need to unpack it to the
# original float representation before merging with the LoRA
# update, and then pack it again after requantization.
if self.quantization_mode == "int4":
# 1) Unpack packed int4 tensor to int8 range [-8, 7].
unpacked_kernel = quantizers.unpack_int4(
kernel_value, self._orig_input_dim
)
# 2) De-scale to recover float32 kernel.
kernel_value_fp = ops.divide(unpacked_kernel, kernel_scale)
# 3) Merge LoRA delta in float32 domain.
kernel_value_fp = ops.add(
kernel_value_fp,
(self.lora_alpha / self.lora_rank)
* ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
)
# 4) Re-quantize to int4 (values still held in int8 dtype).
kernel_int4, kernel_scale = quantizers.abs_max_quantize(
kernel_value_fp,
axis=0,
value_range=(-8, 7),
dtype="int8",
to_numpy=True,
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
# 5) Pack the int4 values back into the compact int8 layout.
kernel_value, _, _ = quantizers.pack_int4(kernel_int4)
else:
# int8 path (regular): unpacking not required.
kernel_value = ops.divide(kernel_value, kernel_scale)
kernel_value = ops.add(
kernel_value,
(self.lora_alpha / self.lora_rank)
* ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
)
kernel_value, kernel_scale = quantizers.abs_max_quantize(
kernel_value, axis=0, to_numpy=True
)
kernel_scale = ops.squeeze(kernel_scale, axis=0)
return kernel_value, kernel_scale
return self.kernel, None
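
For reference, the merge itself is the standard LoRA update applied in the float domain before re-quantization; a small NumPy sketch of just that step (illustrative shapes and values):

import numpy as np

rank, lora_alpha = 2, 2.0
w = np.random.randn(16, 8).astype("float32")      # dequantized kernel
a = np.random.randn(16, rank).astype("float32")   # lora_kernel_a
b = np.random.randn(rank, 8).astype("float32")    # lora_kernel_b

# Effective kernel that is then re-quantized to int4 and re-packed for saving.
w_merged = w + (lora_alpha / rank) * (a @ b)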