From c02e302a022198a8cd3786ed91266df3d5e3b01c Mon Sep 17 00:00:00 2001 From: Jyotinder <33001894+JyotinderSingh@users.noreply.github.com> Date: Sat, 28 Jun 2025 23:03:07 +0530 Subject: [PATCH 01/23] int4 quantization support --- keras/src/dtype_policies/dtype_policy.py | 4 +- keras/src/layers/core/dense.py | 250 +++++++++++++++++++++-- keras/src/layers/core/dense_test.py | 173 +++++++++++++++- keras/src/layers/layer.py | 2 + 4 files changed, 402 insertions(+), 27 deletions(-) diff --git a/keras/src/dtype_policies/dtype_policy.py b/keras/src/dtype_policies/dtype_policy.py index 9d37ac49f9f3..8182a1e45aa4 100644 --- a/keras/src/dtype_policies/dtype_policy.py +++ b/keras/src/dtype_policies/dtype_policy.py @@ -3,7 +3,7 @@ from keras.src.api_export import keras_export from keras.src.backend.common import global_state -QUANTIZATION_MODES = ("int8", "float8") +QUANTIZATION_MODES = ("int8", "float8", "int4") @keras_export( @@ -350,7 +350,7 @@ def _get_quantized_dtype_policy_by_str(policy): f"Received: policy={policy}" ) mode, source_name = split_name - if policy.startswith("int8"): + if policy.startswith("int8") or policy.startswith("int4"): return QuantizedDTypePolicy(mode, source_name) elif policy.startswith("float8"): return QuantizedFloat8DTypePolicy(mode, source_name) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index ed22ee264553..879ab8c28fc4 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -2,7 +2,6 @@ from keras.src import activations from keras.src import constraints -from keras.src import dtype_policies from keras.src import initializers from keras.src import ops from keras.src import quantizers @@ -110,9 +109,10 @@ def build(self, input_shape): kernel_shape = (input_shape[-1], self.units) if self.quantization_mode: self.quantized_build(kernel_shape, mode=self.quantization_mode) - if self.quantization_mode != "int8": - # If the layer is quantized to int8, `self._kernel` will be added - # in `self._int8_build`. Therefore, we skip it here. + if self.quantization_mode not in ("int8", "int4"): + # If the layer is quantized to int8 or int4, `self._kernel` will be + # added in `self._int8_build` or `_int4_build`. Therefore, we skip + # it here. self._kernel = self.add_weight( name="kernel", shape=kernel_shape, @@ -182,9 +182,22 @@ def enable_lora( "lora is already enabled. This can only be done once per layer." ) self._tracker.unlock() + # Determine the correct input dimension for the LoRA A matrix. When + # the layer has been int4-quantized, `self._kernel` stores a *packed* + # representation whose first dimension is `ceil(input_dim/2)`. We + # saved the true, *unpacked* input dimension in `self._orig_input_dim` + # during quantization. Use it if available; otherwise fall back to the + # first dimension of `self.kernel`. 
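+        # (Illustrative: with `input_dim == 5`, the packed kernel has
+        # `ceil(5 / 2) == 3` rows, but `lora_kernel_a` must still have
+        # 5 rows.)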
+ if self.quantization_mode == "int4" and hasattr( + self, "_orig_input_dim" + ): + input_dim_for_lora = self._orig_input_dim + else: + input_dim_for_lora = self.kernel.shape[0] + self.lora_kernel_a = self.add_weight( name="lora_kernel_a", - shape=(self.kernel.shape[0], rank), + shape=(input_dim_for_lora, rank), initializer=initializers.get(a_initializer), regularizer=self.kernel_regularizer, ) @@ -211,7 +224,7 @@ def save_own_variables(self, store): if self.use_bias: target_variables.append(self.bias) if self.quantization_mode is not None: - if self.quantization_mode == "int8": + if self.quantization_mode in ("int8", "int4"): target_variables.append(kernel_scale) elif self.quantization_mode == "float8": target_variables.append(self.inputs_scale) @@ -237,7 +250,7 @@ def load_own_variables(self, store): if self.use_bias: target_variables.append(self.bias) if self.quantization_mode is not None: - if self.quantization_mode == "int8": + if self.quantization_mode in ("int8", "int4"): target_variables.append(self.kernel_scale) elif self.quantization_mode == "float8": target_variables.append(self.inputs_scale) @@ -315,6 +328,8 @@ def _check_load_own_variables(self, store): def quantized_build(self, kernel_shape, mode): if mode == "int8": self._int8_build(kernel_shape) + elif mode == "int4": + self._int4_build(kernel_shape) elif mode == "float8": self._float8_build() else: @@ -337,6 +352,38 @@ def _int8_build(self, kernel_shape): trainable=False, ) + def _int4_build(self, kernel_shape): + """Build variables for int4 quantization. + `kernel_shape` is the *original* float32 kernel shape + `(input_dim, units)`. We allocate the stored kernel with rows + `ceil(input_dim/2)` because two int4 values are packed into a single + int8 byte. + """ + # Per-channel quantizer for the last axis (features). + self.inputs_quantizer = quantizers.AbsMaxQuantizer( + axis=-1, value_range=(-8, 7) + ) + input_dim, output_dim = kernel_shape + packed_rows = (input_dim + 1) // 2 # ceil for odd dims + + # Kernel is stored **packed**: each int8 byte contains two int4 values. + self._kernel = self.add_weight( + name="kernel", + shape=(packed_rows, output_dim), + initializer="zeros", + dtype="int8", + trainable=False, + ) + # One scale per output unit (per-channel). + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=(self.units,), + initializer="ones", + trainable=False, + ) + # Record original input_dim for unpacking at runtime. 
+        self._orig_input_dim = input_dim
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 
@@ -415,6 +462,49 @@ def grad_fn(*args, upstream=None):
             x = self.activation(x)
         return x
 
+    def _int4_call(self, inputs, training=None):
+        """Forward pass for int4 quantized Dense layer."""
+
+        @ops.custom_gradient
+        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            unpacked_kernel = self._unpack_int4_ops(
+                kernel, self._orig_input_dim
+            )
+
+            def grad_fn(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                    kernel_scale,
+                )
+                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
+                return (inputs_grad, None, None)
+
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            x = ops.matmul(inputs, unpacked_kernel)
+            x = ops.cast(x, self.compute_dtype)
+            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            return x, grad_fn
+
+        x = matmul_with_inputs_gradient(
+            inputs,
+            ops.convert_to_tensor(self._kernel),
+            ops.convert_to_tensor(self.kernel_scale),
+        )
+
+        if self.lora_enabled:
+            lora_x = ops.matmul(inputs, self.lora_kernel_a)
+            lora_x = ops.matmul(lora_x, self.lora_kernel_b)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+
+        # Add bias and activation
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
+        return x
+
     def _float8_call(self, inputs, training=None):
         if self.lora_enabled:
             raise NotImplementedError(
@@ -518,13 +608,42 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
-        if mode == "int8":
+            # Build variables for int8 mode
+            self.quantized_build(kernel_shape, mode)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "int4":
+            # 1. Quantize to int4 values (still int8 dtype, range [-8,7])
+            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                self._kernel,
+                axis=0,
+                value_range=(-8, 7),
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_kernel_value, packed_shape, orig_rows = self._pack_int4_ops(
+                kernel_value_int4
+            )
+            del self._kernel
+            # Save original input dim for unpacking.
+            self._orig_input_dim = orig_rows
+            # Build variables using the original kernel shape; _int4_build will
+            # compute the packed shape internally.
+            self.quantized_build(kernel_shape, mode)
+            # Assign packed values.
+            self._kernel.assign(packed_kernel_value)
+            self.kernel_scale.assign(kernel_scale)
+        elif mode == "float8":
+            self.quantized_build(kernel_shape, mode)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Set the new dtype policy only if the layer does not already
+        # use a quantized dtype policy.
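+        # (e.g. quantizing a `"float32"` layer with `mode="int4"` yields
+        # the policy `"int4_from_float32"`.)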
if self.dtype_policy.quantization_mode is None: + from keras.src import dtype_policies # local import to avoid cycle + policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}") self.dtype_policy = policy @@ -533,17 +652,104 @@ def _get_kernel_with_merged_lora(self): kernel_value = self._kernel kernel_scale = self.kernel_scale if self.lora_enabled: - # Dequantize & quantize to merge lora weights into int8 kernel - # Note that this is a lossy compression - kernel_value = ops.divide(kernel_value, kernel_scale) - kernel_value = ops.add( - kernel_value, - (self.lora_alpha / self.lora_rank) - * ops.matmul(self.lora_kernel_a, self.lora_kernel_b), - ) - kernel_value, kernel_scale = quantizers.abs_max_quantize( - kernel_value, axis=0, to_numpy=True - ) - kernel_scale = ops.squeeze(kernel_scale, axis=0) + # For int4, `_kernel` is stored in a packed representation + # (two int4 values per int8 byte). We need to unpack it to the + # original float representation before merging with the LoRA + # update, and then pack it again after requantization. + if self.quantization_mode == "int4": + # 1) Unpack packed int4 tensor to int8 range [-8, 7]. + unpacked_kernel = self._unpack_int4_ops( + kernel_value, self._orig_input_dim + ) + # 2) De-scale to recover float32 kernel. + kernel_value_fp = ops.divide(unpacked_kernel, kernel_scale) + # 3) Merge LoRA delta in float32 domain. + kernel_value_fp = ops.add( + kernel_value_fp, + (self.lora_alpha / self.lora_rank) + * ops.matmul(self.lora_kernel_a, self.lora_kernel_b), + ) + # 4) Re-quantize to int4 (values still held in int8 dtype). + kernel_int4, kernel_scale = quantizers.abs_max_quantize( + kernel_value_fp, + axis=0, + value_range=(-8, 7), + dtype="int8", + to_numpy=True, + ) + kernel_scale = ops.squeeze(kernel_scale, axis=0) + # 5) Pack the int4 values back into the compact int8 layout. + kernel_value, _, _ = self._pack_int4_ops(kernel_int4) + else: + # int8 path (regular): unpacking not required. + kernel_value = ops.divide(kernel_value, kernel_scale) + kernel_value = ops.add( + kernel_value, + (self.lora_alpha / self.lora_rank) + * ops.matmul(self.lora_kernel_a, self.lora_kernel_b), + ) + kernel_value, kernel_scale = quantizers.abs_max_quantize( + kernel_value, axis=0, to_numpy=True + ) + kernel_scale = ops.squeeze(kernel_scale, axis=0) return kernel_value, kernel_scale return self.kernel, None + + def _pack_int4_ops(self, arr): + """Pack an int4 tensor into an int8 tensor with packed nibbles. + + Accepts a Keras-compatible tensor. The input values must already be int8 + in the signed range ``[-8, 7]`` and represent the desired int4 values. + Packing is performed along axis 0: + + * For every two consecutive rows, the **low nibble** of the output byte + stores the value from the first row, and the **high nibble** stores + the value from the second row. + + Returns a tuple ``(packed, packed_shape, orig_rows)`` where ``packed`` + is the packed ``int8`` tensor, ``packed_shape`` is its shape, and + ``orig_rows`` is the original (unpacked) row count prior to any padding + that may have been inserted when an odd number of rows is supplied. 
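+
+        Example (illustrative values): ``[[-3, 7], [2, -8], [1, 0]]``
+        packs along axis 0 into ``[[45, -121], [1, 0]]``; e.g.
+        ``45 == (2 << 4) | 13``, where ``13`` is the 4-bit
+        two's-complement encoding of ``-3``.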
+ """ + if arr.dtype != "int8": + raise TypeError("Expected int8 tensor for packing") + + shape = ops.shape(arr) + rows, cols = shape[0], shape[1] + + orig_rows = rows + if rows % 2 == 1: + padding_row = ops.zeros((1, cols), dtype="int8") + arr = ops.concatenate([arr, padding_row], axis=0) + rows += 1 + + # Map signed [-8,7] to unsigned 4-bit two's complement (0..15) + arr_u = ops.where(arr < 0, arr + 16, arr) + arr_u = ops.cast(arr_u, "uint8") + arr_u = ops.reshape(arr_u, (rows // 2, 2, cols)) + low = arr_u[:, 0, :] + high = arr_u[:, 1, :] + packed = ops.bitwise_or(ops.left_shift(high, 4), low) + packed = ops.cast(packed, "int8") + return packed, ops.shape(packed), orig_rows + + @staticmethod + def _unpack_int4_ops(packed, orig_rows): + """Unpack packed int4 tensor (ops) to int8 [-8,7].""" + # Bitwise operations work element-wise. + low = ops.bitwise_and(packed, 0x0F) + high = ops.right_shift(packed, 4) + high = ops.bitwise_and(high, 0x0F) + + def _to_signed(x): + return ops.where(x < 8, x, ops.subtract(x, 16)) + + low = _to_signed(low) + high = _to_signed(high) + + # Interleave rows back: stacked shape (2, packed_rows, cols) + stacked = ops.stack([low, high], axis=1) # (pairs, 2, cols) + unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1])) + # Remove potential padded row. + unpacked = unpacked_full[:orig_rows, :] + return unpacked diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index ba1073cd97ce..835f675f7e66 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -505,9 +505,9 @@ def test_quantize_invalid_mode(self, mode): @parameterized.named_parameters( ("int8", "int8_from_mixed_bfloat16", 1, 2), ("float8", "float8_from_mixed_bfloat16", 8, 0), + ("int4", "int4_from_mixed_bfloat16", 1, 2), ) @pytest.mark.requires_trainable_backend - @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") def test_quantize_dtype_argument( self, dtype, num_trainable_weights, num_non_trainable_weights ): @@ -524,7 +524,6 @@ def test_quantize_dtype_argument( ) @pytest.mark.requires_trainable_backend - @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") def test_quantize_int8_when_lora_enabled(self): # Note that saving and loading with lora_enabled and quantized are # lossy, so we use a weak correctness test for model outputs (atol=0.5). @@ -606,7 +605,6 @@ def test_quantize_int8_when_lora_enabled(self): ) @pytest.mark.requires_trainable_backend - @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") def test_quantize_float8(self): import ml_dtypes @@ -787,3 +785,172 @@ def test_quantize_float8_inference(self): y_inference = layer(x, training=False) y_training = layer(x, training=True) self.assertAllClose(y_inference, y_training) + + def test_quantize_int4(self): + """Basic correctness / serialization test for int4 quantization.""" + layer = layers.Dense(units=16) + layer.build((None, 8)) + + # Reference (float32) output. + x = np.random.random((2, 8)) + y_float = layer(x) + + # Quantize to int4 and validate kernel dtype / scale dtype. + layer.quantize("int4") + self.assertEqual( + backend.standardize_dtype(layer._kernel.dtype), + "int8", # Packed int4 values are stored as int8 + ) + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), + layer.variable_dtype, + ) + + y_quantized = layer(x) + mse = ops.mean(ops.square(y_float - y_quantized)) + self.assertLess(mse, 2e-3) # Weak correctness check + + # Check model save / load round-trip. 
+ model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_int4_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + # Check weights-only save / load round-trip. + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_int4_model.weights.h5" + ) + model.save_weights(temp_filepath) + new_model = models.Sequential([layers.Dense(units=16)]) + new_model.build((None, 8)) + new_model.quantize("int4") + new_model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + def test_quantize_int4_on_unbuilt_layer(self): + layer = layers.Dense(units=2) + with self.assertRaisesRegex( + ValueError, "Cannot quantize a layer that isn't yet built." + ): + layer.quantize("int4") + + def test_quantize_int4_on_subclass(self): + class MyDense(layers.Dense): + pass + + layer = MyDense(units=16) + layer.build((None, 8)) + with self.assertRaises(NotImplementedError): + layer.quantize("int4") + + # It should succeed when `type_check=False`. + layer.quantize("int4", type_check=False) + + def test_quantize_int4_when_already_quantized(self): + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.quantize("int4") + for m in ["int8", "float8", "int4"]: + with self.assertRaisesRegex( + ValueError, "is already quantized with dtype_policy=" + ): + layer.quantize(m) + + layer = layers.Dense(units=2, dtype="int4_from_float32") + layer.build((None, 2)) + for m in ["int8", "float8", "int4"]: + with self.assertRaisesRegex( + ValueError, "is already quantized with dtype_policy=" + ): + layer.quantize(m) + + def test_quantize_int4_by_setting_dtype_policy(self): + policy = "int4_from_float32" + expected_num_variables = 3 # bias + packed kernel + scale + layer = layers.Dense(units=2) + layer.build((None, 2)) + layer.dtype_policy = policy + self.assertLen(layer.variables, expected_num_variables) + + @pytest.mark.requires_trainable_backend + def test_quantize_int4_when_lora_enabled(self): + config = dict(units=16) + layer = layers.Dense(**config) + layer.build((None, 8)) + layer.enable_lora(4) + layer.quantize("int4") + self.assertLen(layer.trainable_weights, 3) + self.assertLen(layer.non_trainable_weights, 2) + if backend.backend() == "torch": + self.assertLen(layer.torch_params, 5) + + # Try calling fit() + init_lora_a_kernel_value = layer.lora_kernel_a.numpy() + init_lora_b_kernel_value = layer.lora_kernel_b.numpy() + x = np.random.random((64, 8)) + y = np.random.random((64, 16)) + model = models.Sequential([layer]) + model.compile(optimizer="sgd", loss="mse") + model.fit(x, y, epochs=2) + + final_lora_a_kernel_value = layer.lora_kernel_a.numpy() + final_lora_b_kernel_value = layer.lora_kernel_b.numpy() + diff_a = np.max( + np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value) + ) + diff_b = np.max( + np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value) + ) + self.assertGreater(diff_a, 0.0) + self.assertGreater(diff_b, 0.0) + + # Save & reload full model + model = models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_int4_lora_model.keras" + ) + model.save(temp_filepath) + new_model = saving.load_model(temp_filepath) + self.assertTrue(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Save & reload weights only + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_int4_lora_model.weights.h5" + ) + 
model.save_weights(temp_filepath) + new_model = models.Sequential([layers.Dense(**config)]) + new_model.build((None, 8)) + new_model.quantize("int4") + new_model.load_weights(temp_filepath) + self.assertFalse(new_model.layers[0].lora_enabled) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Try loading a normal checkpoint into a lora model + new_model.save_weights(temp_filepath) + model.load_weights(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5) + + # Test export and TFSMLayer reloading when using tensorflow backend + if backend.backend() == "tensorflow": + import tensorflow as tf + + temp_filepath = os.path.join(self.get_temp_dir(), "exported_model") + ref_input = tf.random.normal((2, 8)) + ref_output = model(ref_input) + model.export(temp_filepath, format="tf_saved_model") + reloaded_layer = export.TFSMLayer(temp_filepath) + self.assertAllClose( + reloaded_layer(ref_input), ref_output, atol=1e-7 + ) + self.assertLen(reloaded_layer.weights, len(model.weights)) + self.assertLen( + reloaded_layer.trainable_weights, len(model.trainable_weights) + ) + self.assertLen( + reloaded_layer.non_trainable_weights, + len(model.non_trainable_weights), + ) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index eaff1a8376a2..0e61edb5b111 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1312,6 +1312,8 @@ def quantized_call(self, *args, **kwargs): return self._int8_call(*args, **kwargs) elif self.quantization_mode == "float8": return self._float8_call(*args, **kwargs) + elif self.quantization_mode == "int4": + return self._int4_call(*args, **kwargs) else: raise self._quantization_mode_error(self.quantization_mode) From dd11851e0b472c82e678e1554bddcf57e77aba61 Mon Sep 17 00:00:00 2001 From: Jyotinder <33001894+JyotinderSingh@users.noreply.github.com> Date: Sun, 29 Jun 2025 16:01:42 +0530 Subject: [PATCH 02/23] refactor packing utils into quantizers --- .../_tf_keras/keras/quantizers/__init__.py | 2 + keras/api/quantizers/__init__.py | 2 + keras/src/layers/core/dense.py | 67 ++----------------- keras/src/quantizers/__init__.py | 2 + keras/src/quantizers/quantizers.py | 62 +++++++++++++++++ 5 files changed, 72 insertions(+), 63 deletions(-) diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py index 6400d8faf634..48cf9b034560 100644 --- a/keras/api/_tf_keras/keras/quantizers/__init__.py +++ b/keras/api/_tf_keras/keras/quantizers/__init__.py @@ -19,6 +19,8 @@ from keras.src.quantizers.quantizers import ( fake_quant_with_min_max_vars as fake_quant_with_min_max_vars, ) +from keras.src.quantizers.quantizers import pack_int4 as pack_int4 from keras.src.quantizers.quantizers import ( quantize_and_dequantize as quantize_and_dequantize, ) +from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4 diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py index 6400d8faf634..48cf9b034560 100644 --- a/keras/api/quantizers/__init__.py +++ b/keras/api/quantizers/__init__.py @@ -19,6 +19,8 @@ from keras.src.quantizers.quantizers import ( fake_quant_with_min_max_vars as fake_quant_with_min_max_vars, ) +from keras.src.quantizers.quantizers import pack_int4 as pack_int4 from keras.src.quantizers.quantizers import ( quantize_and_dequantize as quantize_and_dequantize, ) +from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4 diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 
879ab8c28fc4..7d098bffaabe 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -467,7 +467,7 @@ def _int4_call(self, inputs, training=None): @ops.custom_gradient def matmul_with_inputs_gradient(inputs, kernel, kernel_scale): - unpacked_kernel = self._unpack_int4_ops( + unpacked_kernel = quantizers.unpack_int4( kernel, self._orig_input_dim ) @@ -623,7 +623,7 @@ def quantize(self, mode, type_check=True): ) kernel_scale = ops.squeeze(kernel_scale, axis=0) # 2. Pack two int4 values into a single int8 byte. - packed_kernel_value, packed_shape, orig_rows = self._pack_int4_ops( + packed_kernel_value, _, orig_rows = quantizers.pack_int4( kernel_value_int4 ) del self._kernel @@ -658,7 +658,7 @@ def _get_kernel_with_merged_lora(self): # update, and then pack it again after requantization. if self.quantization_mode == "int4": # 1) Unpack packed int4 tensor to int8 range [-8, 7]. - unpacked_kernel = self._unpack_int4_ops( + unpacked_kernel = quantizers.unpack_int4( kernel_value, self._orig_input_dim ) # 2) De-scale to recover float32 kernel. @@ -679,7 +679,7 @@ def _get_kernel_with_merged_lora(self): ) kernel_scale = ops.squeeze(kernel_scale, axis=0) # 5) Pack the int4 values back into the compact int8 layout. - kernel_value, _, _ = self._pack_int4_ops(kernel_int4) + kernel_value, _, _ = quantizers.pack_int4(kernel_int4) else: # int8 path (regular): unpacking not required. kernel_value = ops.divide(kernel_value, kernel_scale) @@ -694,62 +694,3 @@ def _get_kernel_with_merged_lora(self): kernel_scale = ops.squeeze(kernel_scale, axis=0) return kernel_value, kernel_scale return self.kernel, None - - def _pack_int4_ops(self, arr): - """Pack an int4 tensor into an int8 tensor with packed nibbles. - - Accepts a Keras-compatible tensor. The input values must already be int8 - in the signed range ``[-8, 7]`` and represent the desired int4 values. - Packing is performed along axis 0: - - * For every two consecutive rows, the **low nibble** of the output byte - stores the value from the first row, and the **high nibble** stores - the value from the second row. - - Returns a tuple ``(packed, packed_shape, orig_rows)`` where ``packed`` - is the packed ``int8`` tensor, ``packed_shape`` is its shape, and - ``orig_rows`` is the original (unpacked) row count prior to any padding - that may have been inserted when an odd number of rows is supplied. - """ - if arr.dtype != "int8": - raise TypeError("Expected int8 tensor for packing") - - shape = ops.shape(arr) - rows, cols = shape[0], shape[1] - - orig_rows = rows - if rows % 2 == 1: - padding_row = ops.zeros((1, cols), dtype="int8") - arr = ops.concatenate([arr, padding_row], axis=0) - rows += 1 - - # Map signed [-8,7] to unsigned 4-bit two's complement (0..15) - arr_u = ops.where(arr < 0, arr + 16, arr) - arr_u = ops.cast(arr_u, "uint8") - arr_u = ops.reshape(arr_u, (rows // 2, 2, cols)) - low = arr_u[:, 0, :] - high = arr_u[:, 1, :] - packed = ops.bitwise_or(ops.left_shift(high, 4), low) - packed = ops.cast(packed, "int8") - return packed, ops.shape(packed), orig_rows - - @staticmethod - def _unpack_int4_ops(packed, orig_rows): - """Unpack packed int4 tensor (ops) to int8 [-8,7].""" - # Bitwise operations work element-wise. 
- low = ops.bitwise_and(packed, 0x0F) - high = ops.right_shift(packed, 4) - high = ops.bitwise_and(high, 0x0F) - - def _to_signed(x): - return ops.where(x < 8, x, ops.subtract(x, 16)) - - low = _to_signed(low) - high = _to_signed(high) - - # Interleave rows back: stacked shape (2, packed_rows, cols) - stacked = ops.stack([low, high], axis=1) # (pairs, 2, cols) - unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1])) - # Remove potential padded row. - unpacked = unpacked_full[:orig_rows, :] - return unpacked diff --git a/keras/src/quantizers/__init__.py b/keras/src/quantizers/__init__.py index dc7643e1e82a..586530204588 100644 --- a/keras/src/quantizers/__init__.py +++ b/keras/src/quantizers/__init__.py @@ -7,7 +7,9 @@ from keras.src.quantizers.quantizers import compute_float8_amax_history from keras.src.quantizers.quantizers import compute_float8_scale from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars +from keras.src.quantizers.quantizers import pack_int4 from keras.src.quantizers.quantizers import quantize_and_dequantize +from keras.src.quantizers.quantizers import unpack_int4 from keras.src.saving import serialization_lib from keras.src.utils.naming import to_snake_case diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 26ae800ce8f0..1ecffe9ebd90 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -374,3 +374,65 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype): # Dequantize x = ops.multiply(ops.cast(x, compute_dtype), ops.cast(scale, compute_dtype)) return x + + +@keras_export("keras.quantizers.pack_int4") +def pack_int4(arr): + """Pack an int4 tensor into an int8 tensor with packed nibbles. + + Accepts a Keras-compatible tensor. The input values must already be int8 + in the signed range ``[-8, 7]`` and represent the desired int4 values. + Packing is performed along axis 0: + + * For every two consecutive rows, the **low nibble** of the output byte + stores the value from the first row, and the **high nibble** stores + the value from the second row. + + Returns a tuple ``(packed, packed_shape, orig_rows)`` where ``packed`` + is the packed ``int8`` tensor, ``packed_shape`` is its shape, and + ``orig_rows`` is the original (unpacked) row count prior to any padding + that may have been inserted when an odd number of rows is supplied. + """ + if arr.dtype != "int8": + raise TypeError("Expected int8 tensor for packing") + + shape = ops.shape(arr) + rows, cols = shape[0], shape[1] + + orig_rows = rows + if rows % 2 == 1: + padding_row = ops.zeros((1, cols), dtype="int8") + arr = ops.concatenate([arr, padding_row], axis=0) + rows += 1 + + # Map signed [-8,7] to unsigned 4-bit two's complement (0..15) + arr_u = ops.where(arr < 0, arr + 16, arr) + arr_u = ops.cast(arr_u, "uint8") + arr_u = ops.reshape(arr_u, (rows // 2, 2, cols)) + low = arr_u[:, 0, :] + high = arr_u[:, 1, :] + packed = ops.bitwise_or(ops.left_shift(high, 4), low) + packed = ops.cast(packed, "int8") + return packed, ops.shape(packed), orig_rows + + +@keras_export("keras.quantizers.unpack_int4") +def unpack_int4(packed, orig_rows): + """Unpack packed int4 tensor (ops) to int8 [-8,7].""" + # Bitwise operations work element-wise. 
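+    # (Illustrative byte: packed value 45 == 0b00101101 has low nibble
+    # 13, which maps back to -3 after sign mapping, and high nibble 2.)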
+ low = ops.bitwise_and(packed, 0x0F) + high = ops.right_shift(packed, 4) + high = ops.bitwise_and(high, 0x0F) + + def _to_signed(x): + return ops.where(x < 8, x, ops.subtract(x, 16)) + + low = _to_signed(low) + high = _to_signed(high) + + # Interleave rows back: stacked shape (2, packed_rows, cols) + stacked = ops.stack([low, high], axis=1) # (pairs, 2, cols) + unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1])) + # Remove potential padded row. + unpacked = unpacked_full[:orig_rows, :] + return unpacked From 777b5e61c15c6f3204f7ab247a636be2fb0c35a8 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Sun, 29 Jun 2025 21:17:30 +0530 Subject: [PATCH 03/23] generalize int4 packing --- keras/src/quantizers/quantizers.py | 99 ++++++++++++++++--------- keras/src/quantizers/quantizers_test.py | 42 +++++++++++ 2 files changed, 106 insertions(+), 35 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 1ecffe9ebd90..e4f8032f9741 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -377,7 +377,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype): @keras_export("keras.quantizers.pack_int4") -def pack_int4(arr): +def pack_int4(arr, axis=0): """Pack an int4 tensor into an int8 tensor with packed nibbles. Accepts a Keras-compatible tensor. The input values must already be int8 @@ -396,43 +396,72 @@ def pack_int4(arr): if arr.dtype != "int8": raise TypeError("Expected int8 tensor for packing") - shape = ops.shape(arr) - rows, cols = shape[0], shape[1] + rank = arr.shape.rank + # 1. Bring `axis` to the front. + perm = [axis] + [i for i in range(rank) if i != axis] + inv_perm = [perm.index(i) for i in range(rank)] + transposed = ops.transpose(arr, perm) - orig_rows = rows - if rows % 2 == 1: - padding_row = ops.zeros((1, cols), dtype="int8") - arr = ops.concatenate([arr, padding_row], axis=0) - rows += 1 + # 2. Pad to even length. + rows = ops.shape(transposed)[0] + needs_pad = ops.equal(ops.mod(rows, 2), 1) - # Map signed [-8,7] to unsigned 4-bit two's complement (0..15) - arr_u = ops.where(arr < 0, arr + 16, arr) - arr_u = ops.cast(arr_u, "uint8") - arr_u = ops.reshape(arr_u, (rows // 2, 2, cols)) - low = arr_u[:, 0, :] - high = arr_u[:, 1, :] - packed = ops.bitwise_or(ops.left_shift(high, 4), low) - packed = ops.cast(packed, "int8") - return packed, ops.shape(packed), orig_rows + def _pad(x): + pad_shape = ops.concatenate( + [ops.array([1]), ops.array(ops.shape(x)[1:])], axis=0 + ) + pad_row = ops.zeros(pad_shape, dtype="int8") + return ops.concatenate([x, pad_row], axis=0) + + transposed = ops.cond( + needs_pad, lambda: _pad(transposed), lambda: transposed + ) + rows_padded = ops.shape(transposed)[0] + + # 3-4. Group in pairs and pack. + flat_tail = ops.reshape(transposed, (rows_padded // 2, 2, -1)) + low = flat_tail[:, 0, :] + high = flat_tail[:, 1, :] + low_u = ops.where(low < 0, low + 16, low) + high_u = ops.where(high < 0, high + 16, high) + packed_flat = ops.bitwise_or(low_u, ops.left_shift(high_u, 4)) + packed_flat = ops.cast(packed_flat, "int8") + + # 5-6. Restore shape. 
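+    # (Illustrative: for a rank-3 input packed along axis=2, perm is
+    # [2, 0, 1] and inv_perm is [1, 2, 0], restoring the original order.)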
+ packed = ops.reshape( + packed_flat, + ops.concatenate( + [ + ops.expand_dims(rows_padded // 2, 0), + ops.array(ops.shape(transposed)[1:]), + ], + axis=0, + ), + ) + packed = ops.transpose(packed, inv_perm) # back to original order + orig_len = rows # number of slices before padding + return packed, ops.shape(packed), orig_len @keras_export("keras.quantizers.unpack_int4") -def unpack_int4(packed, orig_rows): +def unpack_int4(packed, orig_len, axis=0): """Unpack packed int4 tensor (ops) to int8 [-8,7].""" - # Bitwise operations work element-wise. - low = ops.bitwise_and(packed, 0x0F) - high = ops.right_shift(packed, 4) - high = ops.bitwise_and(high, 0x0F) - - def _to_signed(x): - return ops.where(x < 8, x, ops.subtract(x, 16)) - - low = _to_signed(low) - high = _to_signed(high) - - # Interleave rows back: stacked shape (2, packed_rows, cols) - stacked = ops.stack([low, high], axis=1) # (pairs, 2, cols) - unpacked_full = ops.reshape(stacked, (-1, stacked.shape[-1])) - # Remove potential padded row. - unpacked = unpacked_full[:orig_rows, :] - return unpacked + rank = packed.shape.rank + perm = [axis] + [i for i in range(rank) if i != axis] + inv_perm = [perm.index(i) for i in range(rank)] + transposed = ops.transpose(packed, perm) + + # 1. Split nibbles. + low = ops.bitwise_and(transposed, 0x0F) + high = ops.bitwise_and(ops.right_shift(transposed, 4), 0x0F) + to_signed = lambda x: ops.where(x < 8, x, x - 16) + low = to_signed(low) + high = to_signed(high) + + # 2. Interleave. + stacked = ops.stack([low, high], axis=1) # (pairs, 2, …) + unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:])) + + # 3. Remove possible padding and restore layout. + unpacked = unpacked[:orig_len, ...] + return ops.transpose(unpacked, inv_perm) diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py index 1fc7c94df7d6..63084a91605b 100644 --- a/keras/src/quantizers/quantizers_test.py +++ b/keras/src/quantizers/quantizers_test.py @@ -107,6 +107,48 @@ def test_quantize_and_dequantize(self): # A loose assertion due to an expected quantization error self.assertAllClose(qdq_values, values, atol=5e-1) + @parameterized.named_parameters( + ("even_rows", (4, 5), 0), + ("odd_rows", (5, 5), 0), + ("even_rows_axis_1", (4, 6), 1), + ("odd_rows_axis_1", (4, 7), 1), + ("3d_even_rows_axis_0", (4, 5, 3), 0), + ("3d_odd_rows_axis_0", (5, 5, 3), 0), + ("3d_even_rows_axis_1", (4, 6, 3), 1), + ("3d_odd_rows_axis_1", (4, 7, 3), 1), + ("3d_even_rows_axis_2", (4, 5, 6), 2), + ("3d_odd_rows_axis_2", (4, 5, 7), 2), + ("4d_odd_rows_axis_0", (2, 3, 5, 4), 0), + ("4d_odd_rows_axis_1", (2, 3, 5, 4), 1), + ("4d_odd_rows_axis_2", (2, 3, 5, 4), 2), + ("4d_odd_rows_axis_3", (2, 3, 5, 4), 3), + ("4d_even_rows_axis_0", (2, 4, 5, 4), 0), + ("4d_even_rows_axis_1", (2, 4, 5, 4), 1), + ("4d_even_rows_axis_2", (2, 4, 5, 4), 2), + ("4d_even_rows_axis_3", (2, 4, 5, 4), 3), + ) + def test_pack_unpack_int4(self, shape, axis): + # Create a random tensor with int4 values [-8, 7] + arr = ops.cast( + ops.floor(random.uniform(shape, minval=-8, maxval=8)), "int8" + ) + + # Pack the tensor + packed, packed_shape, orig_len = quantizers.pack_int4(arr, axis=axis) + + # Unpack the tensor + unpacked = quantizers.unpack_int4(packed, orig_len, axis=axis) + + # The unpacked tensor should be the same as the original tensor + self.assertAllClose(unpacked, arr) + + # Test the packed shape + expected_packed_shape = list(shape) + expected_packed_shape[axis] = (expected_packed_shape[axis] + 1) // 2 + self.assertEqual( + 
list(ops.convert_to_numpy(packed_shape)), expected_packed_shape + ) + @parameterized.named_parameters( ("per_tensor", None), ("per_channel", -1), From 72a8cbc104e1026c3817018c1d6185a6cd825246 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:37:35 +0530 Subject: [PATCH 04/23] restored pytest skip conditions --- keras/src/layers/core/dense_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index 835f675f7e66..6056354e64c5 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -508,6 +508,7 @@ def test_quantize_invalid_mode(self, mode): ("int4", "int4_from_mixed_bfloat16", 1, 2), ) @pytest.mark.requires_trainable_backend + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") def test_quantize_dtype_argument( self, dtype, num_trainable_weights, num_non_trainable_weights ): @@ -524,6 +525,7 @@ def test_quantize_dtype_argument( ) @pytest.mark.requires_trainable_backend + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") def test_quantize_int8_when_lora_enabled(self): # Note that saving and loading with lora_enabled and quantized are # lossy, so we use a weak correctness test for model outputs (atol=0.5). @@ -605,6 +607,7 @@ def test_quantize_int8_when_lora_enabled(self): ) @pytest.mark.requires_trainable_backend + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") def test_quantize_float8(self): import ml_dtypes @@ -786,6 +789,7 @@ def test_quantize_float8_inference(self): y_training = layer(x, training=True) self.assertAllClose(y_inference, y_training) + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") def test_quantize_int4(self): """Basic correctness / serialization test for int4 quantization.""" layer = layers.Dense(units=16) @@ -876,6 +880,7 @@ def test_quantize_int4_by_setting_dtype_policy(self): self.assertLen(layer.variables, expected_num_variables) @pytest.mark.requires_trainable_backend + @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault") def test_quantize_int4_when_lora_enabled(self): config = dict(units=16) layer = layers.Dense(**config) From efe244efd6e30c13167dfe2a821d347ebf6ae12b Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:50:09 +0530 Subject: [PATCH 05/23] fixes 'tuple' object has no attribute 'rank' error --- keras/src/quantizers/quantizers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index e4f8032f9741..a5b2c669df1e 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -396,7 +396,9 @@ def pack_int4(arr, axis=0): if arr.dtype != "int8": raise TypeError("Expected int8 tensor for packing") - rank = arr.shape.rank + rank = getattr(arr.shape, "rank", None) + if rank is None: + rank = len(arr.shape) # 1. Bring `axis` to the front. 
perm = [axis] + [i for i in range(rank) if i != axis] inv_perm = [perm.index(i) for i in range(rank)] @@ -446,7 +448,9 @@ def _pad(x): @keras_export("keras.quantizers.unpack_int4") def unpack_int4(packed, orig_len, axis=0): """Unpack packed int4 tensor (ops) to int8 [-8,7].""" - rank = packed.shape.rank + rank = getattr(packed.shape, "rank", None) + if rank is None: + rank = len(packed.shape) perm = [axis] + [i for i in range(rank) if i != axis] inv_perm = [perm.index(i) for i in range(rank)] transposed = ops.transpose(packed, perm) From 72974102391f8cefa479d5b1d36e2f0d90f2b4b2 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 30 Jun 2025 20:45:26 +0530 Subject: [PATCH 06/23] fix dtype check to work across backends --- keras/src/quantizers/quantizers.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index a5b2c669df1e..9e81dfa605e0 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -393,8 +393,10 @@ def pack_int4(arr, axis=0): ``orig_rows`` is the original (unpacked) row count prior to any padding that may have been inserted when an odd number of rows is supplied. """ - if arr.dtype != "int8": - raise TypeError("Expected int8 tensor for packing") + if backend.standardize_dtype(arr.dtype) != "int8": + raise TypeError( + "Expected int8 tensor for packing, got {}".format(arr.dtype) + ) rank = getattr(arr.shape, "rank", None) if rank is None: @@ -448,6 +450,11 @@ def _pad(x): @keras_export("keras.quantizers.unpack_int4") def unpack_int4(packed, orig_len, axis=0): """Unpack packed int4 tensor (ops) to int8 [-8,7].""" + if backend.standardize_dtype(packed.dtype) != "int8": + raise TypeError( + "Expected int8 tensor for unpacking, got {}".format(packed.dtype) + ) + rank = getattr(packed.shape, "rank", None) if rank is None: rank = len(packed.shape) From 3a9e26c0d6058dd3a35129cdbeaeaaba22b78ccf Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 30 Jun 2025 20:55:37 +0530 Subject: [PATCH 07/23] fixed torch compatibility --- keras/src/quantizers/quantizers.py | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 9e81dfa605e0..40f3e9c15fa5 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -411,37 +411,20 @@ def pack_int4(arr, axis=0): needs_pad = ops.equal(ops.mod(rows, 2), 1) def _pad(x): - pad_shape = ops.concatenate( - [ops.array([1]), ops.array(ops.shape(x)[1:])], axis=0 - ) - pad_row = ops.zeros(pad_shape, dtype="int8") + pad_row = ops.zeros_like(x[0:1]) return ops.concatenate([x, pad_row], axis=0) - transposed = ops.cond( - needs_pad, lambda: _pad(transposed), lambda: transposed - ) - rows_padded = ops.shape(transposed)[0] + padded = ops.cond(needs_pad, lambda: _pad(transposed), lambda: transposed) # 3-4. Group in pairs and pack. - flat_tail = ops.reshape(transposed, (rows_padded // 2, 2, -1)) - low = flat_tail[:, 0, :] - high = flat_tail[:, 1, :] + low = padded[::2, ...] + high = padded[1::2, ...] 
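+    # (Descriptive note: even-indexed rows supply the low nibbles and
+    # odd-indexed rows the high nibbles of each packed byte.)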
low_u = ops.where(low < 0, low + 16, low) high_u = ops.where(high < 0, high + 16, high) - packed_flat = ops.bitwise_or(low_u, ops.left_shift(high_u, 4)) - packed_flat = ops.cast(packed_flat, "int8") + packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4)) + packed = ops.cast(packed, "int8") # 5-6. Restore shape. - packed = ops.reshape( - packed_flat, - ops.concatenate( - [ - ops.expand_dims(rows_padded // 2, 0), - ops.array(ops.shape(transposed)[1:]), - ], - axis=0, - ), - ) packed = ops.transpose(packed, inv_perm) # back to original order orig_len = rows # number of slices before padding return packed, ops.shape(packed), orig_len From 9e25042f172cbd7b3a1ba408fa87a39bb62436e6 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 30 Jun 2025 21:18:37 +0530 Subject: [PATCH 08/23] fixed jax compatibility --- keras/src/quantizers/quantizers.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 40f3e9c15fa5..5c5c7f8450f7 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -410,11 +410,18 @@ def pack_int4(arr, axis=0): rows = ops.shape(transposed)[0] needs_pad = ops.equal(ops.mod(rows, 2), 1) - def _pad(x): - pad_row = ops.zeros_like(x[0:1]) - return ops.concatenate([x, pad_row], axis=0) - - padded = ops.cond(needs_pad, lambda: _pad(transposed), lambda: transposed) + # Always append one zero row so the tensor shape is static for JAX. If no + # padding is actually needed, we'll slice it away later. + zero_row = transposed[:1, ...] * 0 # same dtype/shape (1, …) + padded_full = ops.concatenate([transposed, zero_row], axis=0) + + # Number of valid rows after (possible) padding: + # rows + (1 if needs_pad else 0) + rows_packed = rows + ops.cast(needs_pad, "int32") + + # Slice to keep only the valid rows. This keeps the shape rank static while + # allowing the row count to be dynamic. + padded = padded_full[:rows_packed, ...] # 3-4. Group in pairs and pack. low = padded[::2, ...] From 1aa86de98350ee7d444a918da288d5364a1a1101 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 30 Jun 2025 21:36:05 +0530 Subject: [PATCH 09/23] removes redundant self._orig_input_dim initialization --- keras/src/layers/core/dense.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 7d098bffaabe..e5a94f52943f 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -627,8 +627,6 @@ def quantize(self, mode, type_check=True): kernel_value_int4 ) del self._kernel - # Save original input dim for unpacking. - self._orig_input_dim = orig_rows # Build variables using the original kernel shape; _int4_build will # compute the packed shape internally. 
self.quantized_build(kernel_shape, mode) From f9013ae20cc0f2fa6f5a12122520de6298c30027 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Mon, 30 Jun 2025 22:53:45 +0530 Subject: [PATCH 10/23] improves readability --- keras/src/quantizers/quantizers.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 5c5c7f8450f7..ec652a47c666 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -398,9 +398,8 @@ def pack_int4(arr, axis=0): "Expected int8 tensor for packing, got {}".format(arr.dtype) ) - rank = getattr(arr.shape, "rank", None) - if rank is None: - rank = len(arr.shape) + rank = getattr(arr.shape, "rank", None) or len(arr.shape) + # 1. Bring `axis` to the front. perm = [axis] + [i for i in range(rank) if i != axis] inv_perm = [perm.index(i) for i in range(rank)] @@ -445,9 +444,7 @@ def unpack_int4(packed, orig_len, axis=0): "Expected int8 tensor for unpacking, got {}".format(packed.dtype) ) - rank = getattr(packed.shape, "rank", None) - if rank is None: - rank = len(packed.shape) + rank = getattr(packed.shape, "rank", None) or len(packed.shape) perm = [axis] + [i for i in range(rank) if i != axis] inv_perm = [perm.index(i) for i in range(rank)] transposed = ops.transpose(packed, perm) From f3341562f6e832c2e5a02d5b2c679378f19614b8 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Thu, 3 Jul 2025 21:39:59 +0530 Subject: [PATCH 11/23] W4A8 --- keras/src/layers/core/dense.py | 24 ++++++++++++++++++++++-- keras/src/layers/core/dense_test.py | 2 +- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index e5a94f52943f..5b48b1954627 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -359,9 +359,9 @@ def _int4_build(self, kernel_shape): `ceil(input_dim/2)` because two int4 values are packed into a single int8 byte. """ - # Per-channel quantizer for the last axis (features). + # Per-channel int8 quantizer for the last axis (features). self.inputs_quantizer = quantizers.AbsMaxQuantizer( - axis=-1, value_range=(-8, 7) + axis=-1, ) input_dim, output_dim = kernel_shape packed_rows = (input_dim + 1) // 2 # ceil for odd dims @@ -430,6 +430,16 @@ def _float8_build(self): def _int8_call(self, inputs, training=None): @ops.custom_gradient def matmul_with_inputs_gradient(inputs, kernel, kernel_scale): + """Custom gradient function to handle the int8 quantized weights. + + Automatic differentiation will not know how to handle the int8 + quantized weights. So a custom gradient function is needed to + handle the int8 quantized weights. + + The custom gradient function will use the dequantized kernel to + compute the gradient. + """ + def grad_fn(*args, upstream=None): if upstream is None: (upstream,) = args @@ -467,6 +477,16 @@ def _int4_call(self, inputs, training=None): @ops.custom_gradient def matmul_with_inputs_gradient(inputs, kernel, kernel_scale): + """Custom gradient function for int4 quantized weights. + + Automatic differentiation will not know how to handle the + int4 quantized weights. So a custom gradient function is needed + to handle the int4 quantized weights. + + The custom gradient function will use the dequantized kernel to + compute the gradient. 
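+
+            Only the gradient with respect to `inputs` is propagated;
+            the packed kernel and its scale are non-trainable, so their
+            gradients are returned as `None`.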
+ """ + unpacked_kernel = quantizers.unpack_int4( kernel, self._orig_input_dim ) diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py index 6056354e64c5..952b9001f1e7 100644 --- a/keras/src/layers/core/dense_test.py +++ b/keras/src/layers/core/dense_test.py @@ -812,7 +812,7 @@ def test_quantize_int4(self): y_quantized = layer(x) mse = ops.mean(ops.square(y_float - y_quantized)) - self.assertLess(mse, 2e-3) # Weak correctness check + self.assertLess(mse, 15e-4) # Weak correctness check # Check model save / load round-trip. model = models.Sequential([layer]) From f1873062a91ee1714374aaa92dcf6d380b77cdfc Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Fri, 4 Jul 2025 13:25:04 +0530 Subject: [PATCH 12/23] added _int4_call stub --- keras/src/layers/layer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 0e61edb5b111..a6f8562a9690 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1317,6 +1317,9 @@ def quantized_call(self, *args, **kwargs): else: raise self._quantization_mode_error(self.quantization_mode) + def _int4_call(self, *args, **kwargs): + raise self._not_implemented_error(self._int4_call) + def _int8_call(self, *args, **kwargs): raise self._not_implemented_error(self._int8_call) From eed432b89e8e93b2290632e033e01bfe2812b121 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Tue, 8 Jul 2025 11:46:37 +0530 Subject: [PATCH 13/23] Fix bug in unpack that promoted tensor to fp32 --- keras/src/quantizers/quantizers.py | 34 +++++++++++++++++-------- keras/src/quantizers/quantizers_test.py | 3 +++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index ec652a47c666..08b05f1fb543 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -438,10 +438,10 @@ def pack_int4(arr, axis=0): @keras_export("keras.quantizers.unpack_int4") def unpack_int4(packed, orig_len, axis=0): - """Unpack packed int4 tensor (ops) to int8 [-8,7].""" + """Unpack packed int4 tensor (ops) to int8 in range [-8, 7].""" if backend.standardize_dtype(packed.dtype) != "int8": raise TypeError( - "Expected int8 tensor for unpacking, got {}".format(packed.dtype) + f"Expected int8 tensor for unpacking, got {packed.dtype}" ) rank = getattr(packed.shape, "rank", None) or len(packed.shape) @@ -449,17 +449,31 @@ def unpack_int4(packed, orig_len, axis=0): inv_perm = [perm.index(i) for i in range(rank)] transposed = ops.transpose(packed, perm) - # 1. Split nibbles. - low = ops.bitwise_and(transposed, 0x0F) - high = ops.bitwise_and(ops.right_shift(transposed, 4), 0x0F) - to_signed = lambda x: ops.where(x < 8, x, x - 16) + # 1. split nibbles + mask = ops.array([0x0F], dtype="int8") # int8 arrays + low = ops.bitwise_and(transposed, mask) + high = ops.bitwise_and(ops.right_shift(transposed, 4), mask) + + eight = ops.array([8], dtype="int8") + sixteen = ops.array([16], dtype="int8") + + def to_signed(x): + # keep the whole where-expression in int8, + # then cast once more to be certain + return ops.cast( + ops.where(x < eight, x, x - sixteen), + "int8", + ) + low = to_signed(low) high = to_signed(high) - # 2. Interleave. - stacked = ops.stack([low, high], axis=1) # (pairs, 2, …) + # 2. interleave & reshape + stacked = ops.stack([low, high], axis=1) # (pairs, 2, ...) 
unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:])) - # 3. Remove possible padding and restore layout. + # 3. remove padding & restore original layout unpacked = unpacked[:orig_len, ...] - return ops.transpose(unpacked, inv_perm) + unpacked = ops.transpose(unpacked, inv_perm) + + return unpacked # dtype is int8 diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py index 63084a91605b..28213a7cb994 100644 --- a/keras/src/quantizers/quantizers_test.py +++ b/keras/src/quantizers/quantizers_test.py @@ -139,6 +139,9 @@ def test_pack_unpack_int4(self, shape, axis): # Unpack the tensor unpacked = quantizers.unpack_int4(packed, orig_len, axis=axis) + # Verify that the packed tensor is int8 + self.assertDType(packed, "int8") + # The unpacked tensor should be the same as the original tensor self.assertAllClose(unpacked, arr) From 248fcc86ac5dab5ec4ffaec7b1171d89c85f0e93 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Tue, 8 Jul 2025 13:09:15 +0530 Subject: [PATCH 14/23] add missing dtype assertion to quantizer test --- keras/src/quantizers/quantizers_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py index 28213a7cb994..dd70dd6bc016 100644 --- a/keras/src/quantizers/quantizers_test.py +++ b/keras/src/quantizers/quantizers_test.py @@ -142,6 +142,9 @@ def test_pack_unpack_int4(self, shape, axis): # Verify that the packed tensor is int8 self.assertDType(packed, "int8") + # Verify that the unpacked tensor is int8 + self.assertDType(unpacked, "int8") + # The unpacked tensor should be the same as the original tensor self.assertAllClose(unpacked, arr) From 64950761ea23c2b70589eb5a19499a64a38ced03 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Wed, 9 Jul 2025 06:04:26 +0530 Subject: [PATCH 15/23] docstring fixes --- keras/src/quantizers/quantizers.py | 72 +++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 08b05f1fb543..d97ecaa2e584 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -378,20 +378,36 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype): @keras_export("keras.quantizers.pack_int4") def pack_int4(arr, axis=0): - """Pack an int4 tensor into an int8 tensor with packed nibbles. + """ + Pack an int4 tensor into an int8 tensor with packed nibbles. + + The input values must already be int8 in the signed range `[-8, 7]` and + represent the desired int4 values. Packing is performed along the specified + axis (default is 0). - Accepts a Keras-compatible tensor. The input values must already be int8 - in the signed range ``[-8, 7]`` and represent the desired int4 values. - Packing is performed along axis 0: + For every two consecutive rows, the **low nibble** of the output byte + stores the value from the first row, and the **high nibble** stores + the value from the second row. - * For every two consecutive rows, the **low nibble** of the output byte - stores the value from the first row, and the **high nibble** stores - the value from the second row. + Args: + arr: An int8 tensor containing int4 values in the range `[-8, 7]`. + axis: The axis along which to pack the tensor. Defaults to 0. 
- Returns a tuple ``(packed, packed_shape, orig_rows)`` where ``packed`` - is the packed ``int8`` tensor, ``packed_shape`` is its shape, and - ``orig_rows`` is the original (unpacked) row count prior to any padding - that may have been inserted when an odd number of rows is supplied. + Returns: + tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is + the packed int8 tensor with int4 values stored in nibbles, + `packed_shape` is the shape of the packed tensor, and `orig_rows` is + the original (unpacked) row count prior to any padding that may have + been inserted when an odd number of rows is supplied. + + Example: + >>> import numpy as np + >>> from keras.quantizers import pack_int4, unpack_int4 + >>> arr = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) + >>> packed, packed_shape, orig_len = pack_int4(arr, axis=0) + >>> unpacked = unpack_int4(packed, orig_len, axis=0) + >>> np.allclose(arr, unpacked) + True """ if backend.standardize_dtype(arr.dtype) != "int8": raise TypeError( @@ -438,7 +454,39 @@ def pack_int4(arr, axis=0): @keras_export("keras.quantizers.unpack_int4") def unpack_int4(packed, orig_len, axis=0): - """Unpack packed int4 tensor (ops) to int8 in range [-8, 7].""" + """ + Unpack a packed int4 tensor (with values stored in nibbles of int8) + back to an int8 tensor in the range [-8, 7]. + + This function reverses the packing performed by `pack_int4`, restoring + the original int8 tensor (values in the range [-8, 7]) from a packed int8 + tensor where each element contains two int4 values (one in the lower nibble, + one in the upper nibble). + + The function restores the original axis order and removes any + padding that was added during packing. + + Args: + packed: An int8 tensor containing packed int4 values along the + specified axis. Each int8 value encodes two int4 values. + orig_len: The original (unpadded) length of the axis that was + packed. This is used to remove any padding that may have + been added during packing to ensure an even number of rows. + axis: The axis along which the tensor was packed. Defaults to 0. + + Returns: + unpacked: An int8 tensor with the same shape as the original + (unpacked) tensor, with values in the range [-8, 7]. 
+ + Example: + >>> import numpy as np + >>> from keras.quantizers import pack_int4, unpack_int4 + >>> arr = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) + >>> packed, packed_shape, orig_len = pack_int4(arr, axis=0) + >>> unpacked = unpack_int4(packed, orig_len, axis=0) + >>> np.allclose(arr, unpacked) + True + """ if backend.standardize_dtype(packed.dtype) != "int8": raise TypeError( f"Expected int8 tensor for unpacking, got {packed.dtype}" From 0413b362d1afe41c1bad2725fddf4aa5bb9c01ab Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Wed, 9 Jul 2025 11:52:48 +0530 Subject: [PATCH 16/23] docstring fixes --- keras/src/quantizers/quantizers.py | 92 +++++++++++++++++++++++++++--- 1 file changed, 85 insertions(+), 7 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index d97ecaa2e584..9d4251567d0f 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -403,10 +403,49 @@ def pack_int4(arr, axis=0): Example: >>> import numpy as np >>> from keras.quantizers import pack_int4, unpack_int4 - >>> arr = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) - >>> packed, packed_shape, orig_len = pack_int4(arr, axis=0) + + # Example with axis=0 + # Original array has shape (3, 2) + >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) + + # Pack the array along axis 0. Since the length of axis 0 (3) is + # odd, it will be padded to a length of 4. The packed array will + # have a shape of (ceil(3/2), 2) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0) + >>> print("Packed array:\n", packed) + Packed array: + [[ 45 -121] + [ 1 0]] + + # Now, unpack the array back to its original form >>> unpacked = unpack_int4(packed, orig_len, axis=0) - >>> np.allclose(arr, unpacked) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7] + [ 2 -8] + [ 1 0]] + >>> np.allclose(original_array, unpacked) + True + + # Example with axis=1 + # Original array has shape (2, 3) + >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8) + + # Pack along axis 1. Length of axis 1 (3) is padded to 4. + # The new shape is (2, ceil(3/2)) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1) + >>> print("Packed array:\n", packed) + Packed array: + [[ 125 2] + [ 24 0]] + + # Unpack the array + >>> unpacked = unpack_int4(packed, orig_len, axis=1) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7 2] + [-8 1 0]] + >>> np.allclose(original_array, unpacked) True """ if backend.standardize_dtype(arr.dtype) != "int8": @@ -427,7 +466,7 @@ def pack_int4(arr, axis=0): # Always append one zero row so the tensor shape is static for JAX. If no # padding is actually needed, we'll slice it away later. - zero_row = transposed[:1, ...] * 0 # same dtype/shape (1, …) + zero_row = transposed[:1, ...] * 0 # same dtype/shape (1, ...) padded_full = ops.concatenate([transposed, zero_row], axis=0) # Number of valid rows after (possible) padding: @@ -481,10 +520,49 @@ def unpack_int4(packed, orig_len, axis=0): Example: >>> import numpy as np >>> from keras.quantizers import pack_int4, unpack_int4 - >>> arr = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) - >>> packed, packed_shape, orig_len = pack_int4(arr, axis=0) + + # Example with axis=0 + # Original array has shape (3, 2) + >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) + + # Pack the array along axis 0. 
Since the length of axis 0 (3) is + # odd, it will be padded to a length of 4. The packed array will + # have a shape of (ceil(3/2), 2) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0) + >>> print("Packed array:\n", packed) + Packed array: + [[ 45 -121] + [ 1 0]] + + # Now, unpack the array back to its original form >>> unpacked = unpack_int4(packed, orig_len, axis=0) - >>> np.allclose(arr, unpacked) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7] + [ 2 -8] + [ 1 0]] + >>> np.allclose(original_array, unpacked) + True + + # Example with axis=1 + # Original array has shape (2, 3) + >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8) + + # Pack along axis 1. Length of axis 1 (3) is padded to 4. + # The new shape is (2, ceil(3/2)) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1) + >>> print("Packed array:\n", packed) + Packed array: + [[ 125 2] + [ 24 0]] + + # Unpack the array + >>> unpacked = unpack_int4(packed, orig_len, axis=1) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7 2] + [-8 1 0]] + >>> np.allclose(original_array, unpacked) True """ if backend.standardize_dtype(packed.dtype) != "int8": From 052f7b618defbee99e79b9c8bbc476c2bdddc3b1 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Wed, 9 Jul 2025 13:47:16 +0530 Subject: [PATCH 17/23] introduces fastpath for dense unpack --- keras/src/quantizers/quantizers.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 9d4251567d0f..14b56d2f2f04 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -571,6 +571,33 @@ def unpack_int4(packed, orig_len, axis=0): ) rank = getattr(packed.shape, "rank", None) or len(packed.shape) + + # Fast path for the most common case in Dense layers + if axis == 0 and rank == 2: + # The result of the bitwise op is a wider dtype (e.g., int32). + low_unpacked = ops.bitwise_and(packed, 0x0F) + high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), 0x0F) + + # Convert values from [0, 15] to [-8, 7]. + low_signed = ops.where( + low_unpacked < 8, low_unpacked, low_unpacked - 16 + ) + high_signed = ops.where( + high_unpacked < 8, high_unpacked, high_unpacked - 16 + ) + + # Cast back to int8 as the final step. + low = ops.cast(low_signed, "int8") + high = ops.cast(high_signed, "int8") + + # Interleave and reshape + stacked = ops.stack([low, high], axis=1) + unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:])) + + # Remove padding and return + return unpacked[:orig_len, ...] 
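+
+    # Any other axis/rank combination falls through to the general path
+    # below, which transposes the packed axis to the front, unpacks, and
+    # then restores the original layout.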
+ + # General case perm = [axis] + [i for i in range(rank) if i != axis] inv_perm = [perm.index(i) for i in range(rank)] transposed = ops.transpose(packed, perm) From a87687d36fb072f3fe6f8f78b0cd8c673caab386 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Wed, 9 Jul 2025 17:10:29 +0530 Subject: [PATCH 18/23] handle negative axis for pack/unpack --- keras/src/quantizers/quantizers.py | 29 +++++++++++++++---------- keras/src/quantizers/quantizers_test.py | 6 +++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 14b56d2f2f04..8e491249a2df 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -455,6 +455,9 @@ def pack_int4(arr, axis=0): rank = getattr(arr.shape, "rank", None) or len(arr.shape) + if axis < 0: + axis += rank + # 1. Bring `axis` to the front. perm = [axis] + [i for i in range(rank) if i != axis] inv_perm = [perm.index(i) for i in range(rank)] @@ -480,8 +483,11 @@ def pack_int4(arr, axis=0): # 3-4. Group in pairs and pack. low = padded[::2, ...] high = padded[1::2, ...] - low_u = ops.where(low < 0, low + 16, low) - high_u = ops.where(high < 0, high + 16, high) + + mask = ops.array([0x0F], dtype="int8") + low_u = ops.bitwise_and(low, mask) + high_u = ops.bitwise_and(high, mask) + packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4)) packed = ops.cast(packed, "int8") @@ -572,6 +578,9 @@ def unpack_int4(packed, orig_len, axis=0): rank = getattr(packed.shape, "rank", None) or len(packed.shape) + if axis < 0: + axis += rank + # Fast path for the most common case in Dense layers if axis == 0 and rank == 2: # The result of the bitwise op is a wider dtype (e.g., int32). @@ -602,7 +611,7 @@ def unpack_int4(packed, orig_len, axis=0): inv_perm = [perm.index(i) for i in range(rank)] transposed = ops.transpose(packed, perm) - # 1. split nibbles + # 1. Split nibbles. mask = ops.array([0x0F], dtype="int8") # int8 arrays low = ops.bitwise_and(transposed, mask) high = ops.bitwise_and(ops.right_shift(transposed, 4), mask) @@ -611,21 +620,19 @@ def unpack_int4(packed, orig_len, axis=0): sixteen = ops.array([16], dtype="int8") def to_signed(x): - # keep the whole where-expression in int8, - # then cast once more to be certain - return ops.cast( - ops.where(x < eight, x, x - sixteen), - "int8", - ) + return ops.where(x < eight, x, x - sixteen) low = to_signed(low) high = to_signed(high) - # 2. interleave & reshape + # 2. Interleave and reshape. stacked = ops.stack([low, high], axis=1) # (pairs, 2, ...) unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:])) - # 3. remove padding & restore original layout + # 3. Cast back to int8 as the final step. + unpacked = ops.cast(unpacked, "int8") + + # 4. Remove padding and restore original layout. unpacked = unpacked[:orig_len, ...] 
     unpacked = ops.transpose(unpacked, inv_perm)

diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py
index dd70dd6bc016..60ec20a85606 100644
--- a/keras/src/quantizers/quantizers_test.py
+++ b/keras/src/quantizers/quantizers_test.py
@@ -110,6 +110,8 @@ def test_quantize_and_dequantize(self):
     @parameterized.named_parameters(
         ("even_rows", (4, 5), 0),
         ("odd_rows", (5, 5), 0),
+        ("even_rows_axis_1_negative", (4, 5), -1),
+        ("odd_rows_axis_1_negative", (5, 5), -1),
         ("even_rows_axis_1", (4, 6), 1),
         ("odd_rows_axis_1", (4, 7), 1),
         ("3d_even_rows_axis_0", (4, 5, 3), 0),
@@ -126,6 +128,10 @@ def test_quantize_and_dequantize(self):
         ("4d_even_rows_axis_1", (2, 4, 5, 4), 1),
         ("4d_even_rows_axis_2", (2, 4, 5, 4), 2),
         ("4d_even_rows_axis_3", (2, 4, 5, 4), 3),
+        ("4d_even_rows_axis_3_negative", (2, 4, 5, 4), -1),
+        ("4d_even_rows_axis_2_negative", (2, 4, 5, 4), -2),
+        ("4d_even_rows_axis_1_negative", (2, 4, 5, 4), -3),
+        ("4d_even_rows_axis_0_negative", (2, 4, 5, 4), -4),
     )
     def test_pack_unpack_int4(self, shape, axis):
         # Create a random tensor with int4 values [-8, 7]

From 9e2901cd42cdeaa02f69c87c6b0122ac9a2c0bc9 Mon Sep 17 00:00:00 2001
From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com>
Date: Thu, 10 Jul 2025 21:37:53 +0530
Subject: [PATCH 19/23] standardize docs formatting

---
 keras/src/layers/core/dense.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py
index 5b48b1954627..0d988a9a8e64 100644
--- a/keras/src/layers/core/dense.py
+++ b/keras/src/layers/core/dense.py
@@ -366,7 +366,7 @@ def _int4_build(self, kernel_shape):
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
 
-        # Kernel is stored **packed**: each int8 byte contains two int4 values.
+        # Kernel is stored *packed*: each int8 byte contains two int4 values.
         self._kernel = self.add_weight(
             name="kernel",
             shape=(packed_rows, output_dim),
@@ -675,19 +675,19 @@ def _get_kernel_with_merged_lora(self):
         # original float representation before merging with the LoRA
         # update, and then pack it again after requantization.
         if self.quantization_mode == "int4":
-            # 1) Unpack packed int4 tensor to int8 range [-8, 7].
+            # 1. Unpack packed int4 tensor to int8 range [-8, 7].
             unpacked_kernel = quantizers.unpack_int4(
                 kernel_value, self._orig_input_dim
             )
-            # 2) De-scale to recover float32 kernel.
+            # 2. De-scale to recover float32 kernel.
             kernel_value_fp = ops.divide(unpacked_kernel, kernel_scale)
-            # 3) Merge LoRA delta in float32 domain.
+            # 3. Merge LoRA delta in float32 domain.
             kernel_value_fp = ops.add(
                 kernel_value_fp,
                 (self.lora_alpha / self.lora_rank)
                 * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
             )
-            # 4) Re-quantize to int4 (values still held in int8 dtype).
+            # 4. Re-quantize to int4 (values still held in int8 dtype).
             kernel_int4, kernel_scale = quantizers.abs_max_quantize(
                 kernel_value_fp,
                 axis=0,
@@ -696,7 +696,7 @@ def _get_kernel_with_merged_lora(self):
                 to_numpy=True,
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
-            # 5) Pack the int4 values back into the compact int8 layout.
+            # 5. Pack the int4 values back into the compact int8 layout.
             kernel_value, _, _ = quantizers.pack_int4(kernel_int4)
         else:
             # int8 path (regular): unpacking not required.
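For reference, the nibble arithmetic implemented by `pack_int4` and
`unpack_int4` can be sketched in a few lines of plain NumPy. The helper
names below (`pack_ref`, `unpack_ref`, `_to_signed`) are illustrative only
and not part of the Keras API; the sketch round-trips the same example
values used in the docstrings:

```python
import numpy as np

def _to_signed(nibble):
    # Map the unsigned nibble range [8, 15] back to [-8, -1].
    n = nibble.astype(np.int16)
    return np.where(n < 8, n, n - 16).astype(np.int8)

def pack_ref(arr):
    # Pad axis 0 to an even length, then pair consecutive rows.
    if arr.shape[0] % 2:
        pad = np.zeros((1,) + arr.shape[1:], dtype=np.int8)
        arr = np.concatenate([arr, pad], axis=0)
    low, high = arr[::2].astype(np.uint8), arr[1::2].astype(np.uint8)
    # `& 0x0F` yields the two's-complement nibble: -3 -> 13, -8 -> 8.
    packed = (low & 0x0F) | ((high & 0x0F) << 4)
    return packed.astype(np.int8)

def unpack_ref(packed, orig_len):
    u = packed.astype(np.uint8)
    low, high = _to_signed(u & 0x0F), _to_signed(u >> 4)
    stacked = np.stack([low, high], axis=1)  # interleave the row pairs
    return stacked.reshape((-1,) + packed.shape[1:])[:orig_len]

arr = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8)
packed = pack_ref(arr)  # [[45, -121], [1, 0]]
assert np.array_equal(unpack_ref(packed, len(arr)), arr)
```

Working in `uint8` for the shift-and-or step avoids signed-overflow
surprises, and the final `astype(np.int8)` reinterprets each byte, which
matches the storage convention used for the packed kernel.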
From 519e6d7276ccc41a3ff92193685cbea54e5e76c6 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Fri, 11 Jul 2025 05:06:46 +0530 Subject: [PATCH 20/23] fix docstring format --- keras/src/layers/core/dense.py | 1 + keras/src/quantizers/quantizers.py | 207 +++++++++++++++-------------- 2 files changed, 106 insertions(+), 102 deletions(-) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 0d988a9a8e64..2e986f726860 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -354,6 +354,7 @@ def _int8_build(self, kernel_shape): def _int4_build(self, kernel_shape): """Build variables for int4 quantization. + `kernel_shape` is the *original* float32 kernel shape `(input_dim, units)`. We allocate the stored kernel with rows `ceil(input_dim/2)` because two int4 values are packed into a single diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 8e491249a2df..509f4b6b7003 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -378,8 +378,7 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype): @keras_export("keras.quantizers.pack_int4") def pack_int4(arr, axis=0): - """ - Pack an int4 tensor into an int8 tensor with packed nibbles. + """Pack an int4 tensor into an int8 tensor with packed nibbles. The input values must already be int8 in the signed range `[-8, 7]` and represent the desired int4 values. Packing is performed along the specified @@ -395,58 +394,61 @@ def pack_int4(arr, axis=0): Returns: tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is - the packed int8 tensor with int4 values stored in nibbles, - `packed_shape` is the shape of the packed tensor, and `orig_rows` is - the original (unpacked) row count prior to any padding that may have - been inserted when an odd number of rows is supplied. + the packed int8 tensor with int4 values stored in nibbles, + `packed_shape` is the shape of the packed tensor, and `orig_rows` + is the original (unpacked) row count prior to any padding that may + have been inserted when an odd number of rows is supplied. Example: - >>> import numpy as np - >>> from keras.quantizers import pack_int4, unpack_int4 - - # Example with axis=0 - # Original array has shape (3, 2) - >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) - - # Pack the array along axis 0. Since the length of axis 0 (3) is - # odd, it will be padded to a length of 4. The packed array will - # have a shape of (ceil(3/2), 2) = (2, 2). - >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0) - >>> print("Packed array:\n", packed) - Packed array: - [[ 45 -121] - [ 1 0]] - - # Now, unpack the array back to its original form - >>> unpacked = unpack_int4(packed, orig_len, axis=0) - >>> print("Unpacked array:\n", unpacked) - Unpacked array: - [[-3 7] - [ 2 -8] - [ 1 0]] - >>> np.allclose(original_array, unpacked) - True - - # Example with axis=1 - # Original array has shape (2, 3) - >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8) - - # Pack along axis 1. Length of axis 1 (3) is padded to 4. - # The new shape is (2, ceil(3/2)) = (2, 2). 
- >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1) - >>> print("Packed array:\n", packed) - Packed array: - [[ 125 2] - [ 24 0]] - - # Unpack the array - >>> unpacked = unpack_int4(packed, orig_len, axis=1) - >>> print("Unpacked array:\n", unpacked) - Unpacked array: - [[-3 7 2] - [-8 1 0]] - >>> np.allclose(original_array, unpacked) - True + + ```python + >>> import numpy as np + >>> from keras.quantizers import pack_int4, unpack_int4 + + # Example with axis=0 + # Original array has shape (3, 2) + >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) + + # Pack the array along axis 0. Since the length of axis 0 (3) is + # odd, it will be padded to a length of 4. The packed array will + # have a shape of (ceil(3/2), 2) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0) + >>> print("Packed array:\n", packed) + Packed array: + [[ 45 -121] + [ 1 0]] + + # Now, unpack the array back to its original form + >>> unpacked = unpack_int4(packed, orig_len, axis=0) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7] + [ 2 -8] + [ 1 0]] + >>> np.allclose(original_array, unpacked) + True + + # Example with axis=1 + # Original array has shape (2, 3) + >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8) + + # Pack along axis 1. Length of axis 1 (3) is padded to 4. + # The new shape is (2, ceil(3/2)) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1) + >>> print("Packed array:\n", packed) + Packed array: + [[ 125 2] + [ 24 0]] + + # Unpack the array + >>> unpacked = unpack_int4(packed, orig_len, axis=1) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7 2] + [-8 1 0]] + >>> np.allclose(original_array, unpacked) + True + ``` """ if backend.standardize_dtype(arr.dtype) != "int8": raise TypeError( @@ -499,9 +501,7 @@ def pack_int4(arr, axis=0): @keras_export("keras.quantizers.unpack_int4") def unpack_int4(packed, orig_len, axis=0): - """ - Unpack a packed int4 tensor (with values stored in nibbles of int8) - back to an int8 tensor in the range [-8, 7]. + """Unpack a packed int4 back to an int8 tensor in the range [-8, 7]. This function reverses the packing performed by `pack_int4`, restoring the original int8 tensor (values in the range [-8, 7]) from a packed int8 @@ -521,55 +521,58 @@ def unpack_int4(packed, orig_len, axis=0): Returns: unpacked: An int8 tensor with the same shape as the original - (unpacked) tensor, with values in the range [-8, 7]. + (unpacked) tensor, with values in the range [-8, 7]. Example: - >>> import numpy as np - >>> from keras.quantizers import pack_int4, unpack_int4 - - # Example with axis=0 - # Original array has shape (3, 2) - >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) - - # Pack the array along axis 0. Since the length of axis 0 (3) is - # odd, it will be padded to a length of 4. The packed array will - # have a shape of (ceil(3/2), 2) = (2, 2). 
- >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0) - >>> print("Packed array:\n", packed) - Packed array: - [[ 45 -121] - [ 1 0]] - - # Now, unpack the array back to its original form - >>> unpacked = unpack_int4(packed, orig_len, axis=0) - >>> print("Unpacked array:\n", unpacked) - Unpacked array: - [[-3 7] - [ 2 -8] - [ 1 0]] - >>> np.allclose(original_array, unpacked) - True - - # Example with axis=1 - # Original array has shape (2, 3) - >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8) - - # Pack along axis 1. Length of axis 1 (3) is padded to 4. - # The new shape is (2, ceil(3/2)) = (2, 2). - >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1) - >>> print("Packed array:\n", packed) - Packed array: - [[ 125 2] - [ 24 0]] - - # Unpack the array - >>> unpacked = unpack_int4(packed, orig_len, axis=1) - >>> print("Unpacked array:\n", unpacked) - Unpacked array: - [[-3 7 2] - [-8 1 0]] - >>> np.allclose(original_array, unpacked) - True + + ```python + >>> import numpy as np + >>> from keras.quantizers import pack_int4, unpack_int4 + + # Example with axis=0 + # Original array has shape (3, 2) + >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8) + + # Pack the array along axis 0. Since the length of axis 0 (3) is + # odd, it will be padded to a length of 4. The packed array will + # have a shape of (ceil(3/2), 2) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0) + >>> print("Packed array:\n", packed) + Packed array: + [[ 45 -121] + [ 1 0]] + + # Now, unpack the array back to its original form + >>> unpacked = unpack_int4(packed, orig_len, axis=0) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7] + [ 2 -8] + [ 1 0]] + >>> np.allclose(original_array, unpacked) + True + + # Example with axis=1 + # Original array has shape (2, 3) + >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8) + + # Pack along axis 1. Length of axis 1 (3) is padded to 4. + # The new shape is (2, ceil(3/2)) = (2, 2). + >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1) + >>> print("Packed array:\n", packed) + Packed array: + [[ 125 2] + [ 24 0]] + + # Unpack the array + >>> unpacked = unpack_int4(packed, orig_len, axis=1) + >>> print("Unpacked array:\n", unpacked) + Unpacked array: + [[-3 7 2] + [-8 1 0]] + >>> np.allclose(original_array, unpacked) + True + ``` """ if backend.standardize_dtype(packed.dtype) != "int8": raise TypeError( From 41cac4bc39bd3f4d7cbfa87ac77f876a6c7899f7 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Fri, 11 Jul 2025 09:24:56 +0530 Subject: [PATCH 21/23] Reduce duplication in _get_kernel_with_merged_lora --- keras/src/layers/core/dense.py | 94 +++++++++++++++++++++++----------- 1 file changed, 63 insertions(+), 31 deletions(-) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 2e986f726860..b6184fb7bc49 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -667,49 +667,81 @@ def quantize(self, mode, type_check=True): self.dtype_policy = policy def _get_kernel_with_merged_lora(self): + """Returns the kernel with LoRA matrices merged, for serialization. + + This method is called by `save_own_variables` to produce a single + kernel tensor that includes the adaptations from LoRA. This is useful + for deploying the model or for continuing training after permanently + applying the LoRA update. 
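+
+        Note that for quantized modes the merge is lossy: the sum is
+        formed in float and re-quantized with a fresh scale, so the stored
+        values are not bit-identical to the original quantized kernel.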
+ + If the layer is quantized (`int8` or `int4`), the process is: + 1. Dequantize the base kernel to float. + 2. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add + it to the dequantized kernel. + 3. Re-quantize the merged result back to the original quantized + type (`int8` or packed `int4`), calculating a new scale factor. + + If the layer is not quantized, this method returns the result of the + `kernel` property (which computes the merge in floating-point) and a + scale of `None`. + + If LoRA is not enabled, it returns the original kernel and scale + without modification. + + Returns: + A tuple `(kernel_value, kernel_scale)`: + `kernel_value`: The merged kernel. A quantized tensor if + quantization is active, otherwise a high precision tensor. + `kernel_scale`: The quantization scale for the merged kernel. + This is `None` if the layer is not quantized. + """ if self.dtype_policy.quantization_mode is not None: kernel_value = self._kernel kernel_scale = self.kernel_scale if self.lora_enabled: - # For int4, `_kernel` is stored in a packed representation - # (two int4 values per int8 byte). We need to unpack it to the - # original float representation before merging with the LoRA - # update, and then pack it again after requantization. + # Dequantize kernel to float if self.quantization_mode == "int4": - # 1. Unpack packed int4 tensor to int8 range [-8, 7]. unpacked_kernel = quantizers.unpack_int4( kernel_value, self._orig_input_dim ) - # 2. De-scale to recover float32 kernel. - kernel_value_fp = ops.divide(unpacked_kernel, kernel_scale) - # 3. Merge LoRA delta in float32 domain. - kernel_value_fp = ops.add( - kernel_value_fp, - (self.lora_alpha / self.lora_rank) - * ops.matmul(self.lora_kernel_a, self.lora_kernel_b), + float_kernel = ops.divide( + ops.cast(unpacked_kernel, self.compute_dtype), + kernel_scale, ) - # 4. Re-quantize to int4 (values still held in int8 dtype). - kernel_int4, kernel_scale = quantizers.abs_max_quantize( - kernel_value_fp, - axis=0, - value_range=(-8, 7), - dtype="int8", - to_numpy=True, + quant_range = (-8, 7) + elif self.quantization_mode == "int8": + float_kernel = ops.divide( + ops.cast(kernel_value, self.compute_dtype), kernel_scale ) - kernel_scale = ops.squeeze(kernel_scale, axis=0) - # 5. Pack the int4 values back into the compact int8 layout. - kernel_value, _, _ = quantizers.pack_int4(kernel_int4) + quant_range = (-127, 127) else: - # int8 path (regular): unpacking not required. 
- kernel_value = ops.divide(kernel_value, kernel_scale) - kernel_value = ops.add( - kernel_value, - (self.lora_alpha / self.lora_rank) - * ops.matmul(self.lora_kernel_a, self.lora_kernel_b), + raise ValueError( + "Unsupported quantization mode: " + f"{self.quantization_mode}" ) - kernel_value, kernel_scale = quantizers.abs_max_quantize( - kernel_value, axis=0, to_numpy=True + + # Merge LoRA weights in float domain + lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul( + self.lora_kernel_a, self.lora_kernel_b + ) + merged_float_kernel = ops.add(float_kernel, lora_delta) + + # Requantize + requantized_kernel, kernel_scale = quantizers.abs_max_quantize( + merged_float_kernel, + axis=0, + value_range=quant_range, + dtype="int8", + to_numpy=True, + ) + kernel_scale = ops.squeeze(kernel_scale, axis=0) + + # Pack if int4 + if self.quantization_mode == "int4": + kernel_value, _, _ = quantizers.pack_int4( + requantized_kernel ) - kernel_scale = ops.squeeze(kernel_scale, axis=0) + else: + kernel_value = requantized_kernel return kernel_value, kernel_scale return self.kernel, None From 0d5c3bd7b1c77372133b8bdee839a7d211a3d1ae Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Fri, 11 Jul 2025 09:31:27 +0530 Subject: [PATCH 22/23] remove unnecessary cast ops --- keras/src/quantizers/quantizers.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py index 509f4b6b7003..e1c842fe00e8 100644 --- a/keras/src/quantizers/quantizers.py +++ b/keras/src/quantizers/quantizers.py @@ -486,7 +486,7 @@ def pack_int4(arr, axis=0): low = padded[::2, ...] high = padded[1::2, ...] - mask = ops.array([0x0F], dtype="int8") + mask = ops.array(0x0F, dtype="int8") low_u = ops.bitwise_and(low, mask) high_u = ops.bitwise_and(high, mask) @@ -587,8 +587,9 @@ def unpack_int4(packed, orig_len, axis=0): # Fast path for the most common case in Dense layers if axis == 0 and rank == 2: # The result of the bitwise op is a wider dtype (e.g., int32). - low_unpacked = ops.bitwise_and(packed, 0x0F) - high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), 0x0F) + mask = ops.array(0x0F, dtype=packed.dtype) + low_unpacked = ops.bitwise_and(packed, mask) + high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask) # Convert values from [0, 15] to [-8, 7]. low_signed = ops.where( @@ -598,12 +599,8 @@ def unpack_int4(packed, orig_len, axis=0): high_unpacked < 8, high_unpacked, high_unpacked - 16 ) - # Cast back to int8 as the final step. - low = ops.cast(low_signed, "int8") - high = ops.cast(high_signed, "int8") - # Interleave and reshape - stacked = ops.stack([low, high], axis=1) + stacked = ops.stack([low_signed, high_signed], axis=1) unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:])) # Remove padding and return @@ -615,12 +612,12 @@ def unpack_int4(packed, orig_len, axis=0): transposed = ops.transpose(packed, perm) # 1. Split nibbles. 
- mask = ops.array([0x0F], dtype="int8") # int8 arrays + mask = ops.array(0x0F, dtype="int8") # int8 arrays low = ops.bitwise_and(transposed, mask) high = ops.bitwise_and(ops.right_shift(transposed, 4), mask) - eight = ops.array([8], dtype="int8") - sixteen = ops.array([16], dtype="int8") + eight = ops.array(8, dtype="int8") + sixteen = ops.array(16, dtype="int8") def to_signed(x): return ops.where(x < eight, x, x - sixteen) @@ -632,9 +629,6 @@ def to_signed(x): stacked = ops.stack([low, high], axis=1) # (pairs, 2, ...) unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:])) - # 3. Cast back to int8 as the final step. - unpacked = ops.cast(unpacked, "int8") - # 4. Remove padding and restore original layout. unpacked = unpacked[:orig_len, ...] unpacked = ops.transpose(unpacked, inv_perm) From 98fa1ed653bd0196ec209686015aac097128c843 Mon Sep 17 00:00:00 2001 From: Jyotinder Singh <33001894+JyotinderSingh@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:18:42 +0530 Subject: [PATCH 23/23] removes unused var --- keras/src/layers/core/dense.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index b6184fb7bc49..725137da8f0c 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -644,9 +644,7 @@ def quantize(self, mode, type_check=True): ) kernel_scale = ops.squeeze(kernel_scale, axis=0) # 2. Pack two int4 values into a single int8 byte. - packed_kernel_value, _, orig_rows = quantizers.pack_int4( - kernel_value_int4 - ) + packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4) del self._kernel # Build variables using the original kernel shape; _int4_build will # compute the packed shape internally.
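Taken together, the last three patches leave the LoRA merge path as:
dequantize the stored kernel, add the scaled `lora_kernel_a @ lora_kernel_b`
delta in float, re-quantize with a fresh per-channel scale, and re-pack when
in int4 mode. A minimal NumPy sketch of the int8 flow, assuming the
`dequantized = q / scale` convention implied by the `ops.divide` calls above
(the `abs_max_quantize` below is a simplified stand-in, not the Keras
function):

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, units, rank, alpha = 8, 4, 2, 2.0

kernel = rng.normal(size=(in_dim, units)).astype(np.float32)
lora_a = rng.normal(size=(in_dim, rank)).astype(np.float32)
lora_b = rng.normal(size=(rank, units)).astype(np.float32)

def abs_max_quantize(w, qmax=127):
    # One scale per output column; dequantize with q / scale.
    scale = qmax / np.maximum(np.abs(w).max(axis=0), 1e-9)
    q = np.clip(np.round(w * scale), -qmax, qmax).astype(np.int8)
    return q, scale.astype(np.float32)

# Stored state of an int8-quantized layer.
q, scale = abs_max_quantize(kernel)

# Dequantize, add the scaled LoRA delta, then requantize.
merged = q / scale + (alpha / rank) * (lora_a @ lora_b)
q2, scale2 = abs_max_quantize(merged)

# The round-trip error is at most half a quantization step per column.
assert np.max(np.abs(q2 / scale2 - merged)) <= 0.5 / scale2.min() + 1e-6
```

The int4 flow is identical apart from the `(-8, 7)` value range and the
pack/unpack step around storage.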