diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py
index 6400d8faf634..48cf9b034560 100644
--- a/keras/api/_tf_keras/keras/quantizers/__init__.py
+++ b/keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -19,6 +19,8 @@
 from keras.src.quantizers.quantizers import (
     fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
 )
+from keras.src.quantizers.quantizers import pack_int4 as pack_int4
 from keras.src.quantizers.quantizers import (
     quantize_and_dequantize as quantize_and_dequantize,
 )
+from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4
diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py
index 6400d8faf634..48cf9b034560 100644
--- a/keras/api/quantizers/__init__.py
+++ b/keras/api/quantizers/__init__.py
@@ -19,6 +19,8 @@
 from keras.src.quantizers.quantizers import (
     fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,
 )
+from keras.src.quantizers.quantizers import pack_int4 as pack_int4
 from keras.src.quantizers.quantizers import (
     quantize_and_dequantize as quantize_and_dequantize,
 )
+from keras.src.quantizers.quantizers import unpack_int4 as unpack_int4
diff --git a/keras/src/dtype_policies/dtype_policy.py b/keras/src/dtype_policies/dtype_policy.py
index 9d37ac49f9f3..8182a1e45aa4 100644
--- a/keras/src/dtype_policies/dtype_policy.py
+++ b/keras/src/dtype_policies/dtype_policy.py
@@ -3,7 +3,7 @@
 from keras.src.api_export import keras_export
 from keras.src.backend.common import global_state
 
-QUANTIZATION_MODES = ("int8", "float8")
+QUANTIZATION_MODES = ("int8", "float8", "int4")
 
 
 @keras_export(
@@ -350,7 +350,7 @@ def _get_quantized_dtype_policy_by_str(policy):
             f"Received: policy={policy}"
         )
     mode, source_name = split_name
-    if policy.startswith("int8"):
+    if policy.startswith("int8") or policy.startswith("int4"):
        return QuantizedDTypePolicy(mode, source_name)
    elif policy.startswith("float8"):
        return QuantizedFloat8DTypePolicy(mode, source_name)
diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py
index ed22ee264553..725137da8f0c 100644
--- a/keras/src/layers/core/dense.py
+++ b/keras/src/layers/core/dense.py
@@ -2,7 +2,6 @@
 from keras.src import activations
 from keras.src import constraints
-from keras.src import dtype_policies
 from keras.src import initializers
 from keras.src import ops
 from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
             self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode != "int8":
-            # If the layer is quantized to int8, `self._kernel` will be added
-            # in `self._int8_build`. Therefore, we skip it here.
+        if self.quantization_mode not in ("int8", "int4"):
+            # If the layer is quantized to int8 or int4, `self._kernel` will be
+            # added in `self._int8_build` or `_int4_build`. Therefore, we skip
+            # it here.
             self._kernel = self.add_weight(
                 name="kernel",
                 shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+        # Determine the correct input dimension for the LoRA A matrix. When
+        # the layer has been int4-quantized, `self._kernel` stores a *packed*
+        # representation whose first dimension is `ceil(input_dim/2)`. We
+        # saved the true, *unpacked* input dimension in `self._orig_input_dim`
+        # during quantization. Use it if available; otherwise fall back to the
+        # first dimension of `self.kernel`.
+        if self.quantization_mode == "int4" and hasattr(
+            self, "_orig_input_dim"
+        ):
+            input_dim_for_lora = self._orig_input_dim
+        else:
+            input_dim_for_lora = self.kernel.shape[0]
+
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
-            shape=(self.kernel.shape[0], rank),
+            shape=(input_dim_for_lora, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(self.kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
     def quantized_build(self, kernel_shape, mode):
         if mode == "int8":
             self._int8_build(kernel_shape)
+        elif mode == "int4":
+            self._int4_build(kernel_shape)
         elif mode == "float8":
             self._float8_build()
         else:
@@ -337,6 +352,39 @@ def _int8_build(self, kernel_shape):
             trainable=False,
         )
 
+    def _int4_build(self, kernel_shape):
+        """Build variables for int4 quantization.
+
+        `kernel_shape` is the *original* float32 kernel shape
+        `(input_dim, units)`. We allocate the stored kernel with rows
+        `ceil(input_dim/2)` because two int4 values are packed into a single
+        int8 byte.
+        """
+        # Per-channel int8 quantizer for the last axis (features).
+        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
+            axis=-1,
+        )
+        input_dim, output_dim = kernel_shape
+        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
+
+        # Kernel is stored *packed*: each int8 byte contains two int4 values.
+        self._kernel = self.add_weight(
+            name="kernel",
+            shape=(packed_rows, output_dim),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        # One scale per output unit (per-channel).
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record original input_dim for unpacking at runtime.
+        self._orig_input_dim = input_dim
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
 
@@ -383,6 +431,16 @@ def _float8_build(self):
     def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
         def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function to handle the int8 quantized weights.
+
+            Automatic differentiation will not know how to handle the int8
+            quantized weights. So a custom gradient function is needed to
+            handle the int8 quantized weights.
+
+            The custom gradient function will use the dequantized kernel to
+            compute the gradient.
+            """
+
             def grad_fn(*args, upstream=None):
                 if upstream is None:
                     (upstream,) = args
@@ -415,6 +473,59 @@ def grad_fn(*args, upstream=None):
         x = self.activation(x)
         return x
 
+    def _int4_call(self, inputs, training=None):
+        """Forward pass for int4 quantized Dense layer."""
+
+        @ops.custom_gradient
+        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function for int4 quantized weights.
+
+            Automatic differentiation will not know how to handle the
+            int4 quantized weights. So a custom gradient function is needed
+            to handle the int4 quantized weights.
+
+            The custom gradient function will use the dequantized kernel to
+            compute the gradient.
+            """
+
+            unpacked_kernel = quantizers.unpack_int4(
+                kernel, self._orig_input_dim
+            )
+
+            def grad_fn(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                    kernel_scale,
+                )
+                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
+                return (inputs_grad, None, None)
+
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            x = ops.matmul(inputs, unpacked_kernel)
+            x = ops.cast(x, self.compute_dtype)
+            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            return x, grad_fn
+
+        x = matmul_with_inputs_gradient(
+            inputs,
+            ops.convert_to_tensor(self._kernel),
+            ops.convert_to_tensor(self.kernel_scale),
+        )
+
+        if self.lora_enabled:
+            lora_x = ops.matmul(inputs, self.lora_kernel_a)
+            lora_x = ops.matmul(lora_x, self.lora_kernel_b)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+
+        # Add bias and activation
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
+        return x
+
     def _float8_call(self, inputs, training=None):
         if self.lora_enabled:
             raise NotImplementedError(
@@ -518,32 +629,117 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
-        if mode == "int8":
+            # Build variables for int8 mode
+            self.quantized_build(kernel_shape, mode)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "int4":
+            # 1. Quantize to int4 values (still int8 dtype, range [-8, 7])
+            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                self._kernel,
+                axis=0,
+                value_range=(-8, 7),
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
+            del self._kernel
+            # Build variables using the original kernel shape; _int4_build will
+            # compute the packed shape internally.
+            self.quantized_build(kernel_shape, mode)
+            # Assign packed values.
+            self._kernel.assign(packed_kernel_value)
+            self.kernel_scale.assign(kernel_scale)
+        elif mode == "float8":
+            self.quantized_build(kernel_shape, mode)
+        else:
+            raise self._quantization_mode_error(mode)
 
-        # Set new dtype policy
+        # Set the new dtype policy only if the layer is not already quantized.
         if self.dtype_policy.quantization_mode is None:
+            from keras.src import dtype_policies  # local import to avoid cycle
+
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy
 
     def _get_kernel_with_merged_lora(self):
+        """Returns the kernel with LoRA matrices merged, for serialization.
+
+        This method is called by `save_own_variables` to produce a single
+        kernel tensor that includes the adaptations from LoRA. This is useful
+        for deploying the model or for continuing training after permanently
+        applying the LoRA update.
+
+        If the layer is quantized (`int8` or `int4`), the process is:
+        1. Dequantize the base kernel to float.
+        2. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add
+           it to the dequantized kernel.
+        3. Re-quantize the merged result back to the original quantized
+           type (`int8` or packed `int4`), calculating a new scale factor.
+
+        If the layer is not quantized, this method returns the result of the
+        `kernel` property (which computes the merge in floating-point) and a
+        scale of `None`.
+
+        If LoRA is not enabled, it returns the original kernel and scale
+        without modification.
+
+        Returns:
+            A tuple `(kernel_value, kernel_scale)`:
+                `kernel_value`: The merged kernel. A quantized tensor if
+                    quantization is active, otherwise a high precision tensor.
+                `kernel_scale`: The quantization scale for the merged kernel.
+                    This is `None` if the layer is not quantized.
+        """
         if self.dtype_policy.quantization_mode is not None:
             kernel_value = self._kernel
             kernel_scale = self.kernel_scale
             if self.lora_enabled:
-                # Dequantize & quantize to merge lora weights into int8 kernel
-                # Note that this is a lossy compression
-                kernel_value = ops.divide(kernel_value, kernel_scale)
-                kernel_value = ops.add(
-                    kernel_value,
-                    (self.lora_alpha / self.lora_rank)
-                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                # Dequantize kernel to float
+                if self.quantization_mode == "int4":
+                    unpacked_kernel = quantizers.unpack_int4(
+                        kernel_value, self._orig_input_dim
+                    )
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, self.compute_dtype),
+                        kernel_scale,
+                    )
+                    quant_range = (-8, 7)
+                elif self.quantization_mode == "int8":
+                    float_kernel = ops.divide(
+                        ops.cast(kernel_value, self.compute_dtype), kernel_scale
+                    )
+                    quant_range = (-127, 127)
+                else:
+                    raise ValueError(
+                        "Unsupported quantization mode: "
+                        f"{self.quantization_mode}"
+                    )
+
+                # Merge LoRA weights in float domain
+                lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+                    self.lora_kernel_a, self.lora_kernel_b
                 )
-                kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=0, to_numpy=True
+                merged_float_kernel = ops.add(float_kernel, lora_delta)
+
+                # Requantize
+                requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                    merged_float_kernel,
+                    axis=0,
+                    value_range=quant_range,
+                    dtype="int8",
+                    to_numpy=True,
                 )
                 kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
+                # Pack if int4
+                if self.quantization_mode == "int4":
+                    kernel_value, _, _ = quantizers.pack_int4(
+                        requantized_kernel
+                    )
+                else:
+                    kernel_value = requantized_kernel
             return kernel_value, kernel_scale
         return self.kernel, None
diff --git a/keras/src/layers/core/dense_test.py b/keras/src/layers/core/dense_test.py
index ba1073cd97ce..952b9001f1e7 100644
--- a/keras/src/layers/core/dense_test.py
+++ b/keras/src/layers/core/dense_test.py
@@ -505,6 +505,7 @@ def test_quantize_invalid_mode(self, mode):
     @parameterized.named_parameters(
         ("int8", "int8_from_mixed_bfloat16", 1, 2),
         ("float8", "float8_from_mixed_bfloat16", 8, 0),
+        ("int4", "int4_from_mixed_bfloat16", 1, 2),
     )
     @pytest.mark.requires_trainable_backend
     @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault")
@@ -787,3 +788,174 @@ def test_quantize_float8_inference(self):
         y_inference = layer(x, training=False)
         y_training = layer(x, training=True)
         self.assertAllClose(y_inference, y_training)
+
+    @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault")
+    def test_quantize_int4(self):
+        """Basic correctness / serialization test for int4 quantization."""
+        layer = layers.Dense(units=16)
+        layer.build((None, 8))
+
+        # Reference (float32) output.
+        x = np.random.random((2, 8))
+        y_float = layer(x)
+
+        # Quantize to int4 and validate kernel dtype / scale dtype.
+        layer.quantize("int4")
+        self.assertEqual(
+            backend.standardize_dtype(layer._kernel.dtype),
+            "int8",  # Packed int4 values are stored as int8
+        )
+        self.assertEqual(
+            backend.standardize_dtype(layer.kernel_scale.dtype),
+            layer.variable_dtype,
+        )
+
+        y_quantized = layer(x)
+        mse = ops.mean(ops.square(y_float - y_quantized))
+        self.assertLess(mse, 15e-4)  # Weak correctness check
+
+        # Check model save / load round-trip.
+        model = models.Sequential([layer])
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "quantized_int4_model.keras"
+        )
+        model.save(temp_filepath)
+        new_model = saving.load_model(temp_filepath)
+        self.assertAllClose(model.predict(x), new_model.predict(x))
+
+        # Check weights-only save / load round-trip.
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "quantized_int4_model.weights.h5"
+        )
+        model.save_weights(temp_filepath)
+        new_model = models.Sequential([layers.Dense(units=16)])
+        new_model.build((None, 8))
+        new_model.quantize("int4")
+        new_model.load_weights(temp_filepath)
+        self.assertAllClose(model.predict(x), new_model.predict(x))
+
+    def test_quantize_int4_on_unbuilt_layer(self):
+        layer = layers.Dense(units=2)
+        with self.assertRaisesRegex(
+            ValueError, "Cannot quantize a layer that isn't yet built."
+        ):
+            layer.quantize("int4")
+
+    def test_quantize_int4_on_subclass(self):
+        class MyDense(layers.Dense):
+            pass
+
+        layer = MyDense(units=16)
+        layer.build((None, 8))
+        with self.assertRaises(NotImplementedError):
+            layer.quantize("int4")
+
+        # It should succeed when `type_check=False`.
+        layer.quantize("int4", type_check=False)
+
+    def test_quantize_int4_when_already_quantized(self):
+        layer = layers.Dense(units=2)
+        layer.build((None, 2))
+        layer.quantize("int4")
+        for m in ["int8", "float8", "int4"]:
+            with self.assertRaisesRegex(
+                ValueError, "is already quantized with dtype_policy="
+            ):
+                layer.quantize(m)
+
+        layer = layers.Dense(units=2, dtype="int4_from_float32")
+        layer.build((None, 2))
+        for m in ["int8", "float8", "int4"]:
+            with self.assertRaisesRegex(
+                ValueError, "is already quantized with dtype_policy="
+            ):
+                layer.quantize(m)
+
+    def test_quantize_int4_by_setting_dtype_policy(self):
+        policy = "int4_from_float32"
+        expected_num_variables = 3  # bias + packed kernel + scale
+        layer = layers.Dense(units=2)
+        layer.build((None, 2))
+        layer.dtype_policy = policy
+        self.assertLen(layer.variables, expected_num_variables)
+
+    @pytest.mark.requires_trainable_backend
+    @pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason="Segfault")
+    def test_quantize_int4_when_lora_enabled(self):
+        config = dict(units=16)
+        layer = layers.Dense(**config)
+        layer.build((None, 8))
+        layer.enable_lora(4)
+        layer.quantize("int4")
+        self.assertLen(layer.trainable_weights, 3)
+        self.assertLen(layer.non_trainable_weights, 2)
+        if backend.backend() == "torch":
+            self.assertLen(layer.torch_params, 5)
+
+        # Try calling fit()
+        init_lora_a_kernel_value = layer.lora_kernel_a.numpy()
+        init_lora_b_kernel_value = layer.lora_kernel_b.numpy()
+        x = np.random.random((64, 8))
+        y = np.random.random((64, 16))
+        model = models.Sequential([layer])
+        model.compile(optimizer="sgd", loss="mse")
+        model.fit(x, y, epochs=2)
+
+        final_lora_a_kernel_value = layer.lora_kernel_a.numpy()
+        final_lora_b_kernel_value = layer.lora_kernel_b.numpy()
+        diff_a = np.max(
+            np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value)
+        )
+        diff_b = np.max(
+            np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value)
+        )
+        self.assertGreater(diff_a, 0.0)
+        self.assertGreater(diff_b, 0.0)
+
+        # Save & reload full model
+        model = models.Sequential([layer])
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "quantized_int4_lora_model.keras"
+        )
+        model.save(temp_filepath)
+        new_model = saving.load_model(temp_filepath)
+        self.assertTrue(new_model.layers[0].lora_enabled)
+        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)
+
+        # Save & reload weights only
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "quantized_int4_lora_model.weights.h5"
+        )
+        model.save_weights(temp_filepath)
+        new_model = models.Sequential([layers.Dense(**config)])
+        new_model.build((None, 8))
+        new_model.quantize("int4")
+        new_model.load_weights(temp_filepath)
+        self.assertFalse(new_model.layers[0].lora_enabled)
+        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)
+
+        # Try loading a normal checkpoint into a lora model
+        new_model.save_weights(temp_filepath)
+        model.load_weights(temp_filepath)
+        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)
+
+        # Test export and TFSMLayer reloading when using tensorflow backend
+        if backend.backend() == "tensorflow":
+            import tensorflow as tf
+
+            temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
+            ref_input = tf.random.normal((2, 8))
+            ref_output = model(ref_input)
+            model.export(temp_filepath, format="tf_saved_model")
+            reloaded_layer = export.TFSMLayer(temp_filepath)
+            self.assertAllClose(
+                reloaded_layer(ref_input), ref_output, atol=1e-7
+            )
+            self.assertLen(reloaded_layer.weights, len(model.weights))
+            self.assertLen(
+                reloaded_layer.trainable_weights, len(model.trainable_weights)
+            )
+            self.assertLen(
+                reloaded_layer.non_trainable_weights,
+                len(model.non_trainable_weights),
+            )
diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index eaff1a8376a2..a6f8562a9690 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -1312,9 +1312,14 @@ def quantized_call(self, *args, **kwargs):
             return self._int8_call(*args, **kwargs)
         elif self.quantization_mode == "float8":
             return self._float8_call(*args, **kwargs)
+        elif self.quantization_mode == "int4":
+            return self._int4_call(*args, **kwargs)
         else:
             raise self._quantization_mode_error(self.quantization_mode)
 
+    def _int4_call(self, *args, **kwargs):
+        raise self._not_implemented_error(self._int4_call)
+
     def _int8_call(self, *args, **kwargs):
         raise self._not_implemented_error(self._int8_call)
 
diff --git a/keras/src/quantizers/__init__.py b/keras/src/quantizers/__init__.py
index dc7643e1e82a..586530204588 100644
--- a/keras/src/quantizers/__init__.py
+++ b/keras/src/quantizers/__init__.py
@@ -7,7 +7,9 @@
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
 from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
+from keras.src.quantizers.quantizers import pack_int4
 from keras.src.quantizers.quantizers import quantize_and_dequantize
+from keras.src.quantizers.quantizers import unpack_int4
 from keras.src.saving import serialization_lib
 from keras.src.utils.naming import to_snake_case
diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py
index 26ae800ce8f0..e1c842fe00e8 100644
--- a/keras/src/quantizers/quantizers.py
+++ b/keras/src/quantizers/quantizers.py
@@ -374,3 +374,263 @@ def quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):
     # Dequantize
     x = ops.multiply(ops.cast(x, compute_dtype), ops.cast(scale, compute_dtype))
     return x
+
+
+@keras_export("keras.quantizers.pack_int4")
+def pack_int4(arr, axis=0):
+    """Pack an int4 tensor into an int8 tensor with packed nibbles.
+
+    The input values must already be int8 in the signed range `[-8, 7]` and
+    represent the desired int4 values. Packing is performed along the specified
+    axis (default is 0).
+
+    For every two consecutive rows, the **low nibble** of the output byte
+    stores the value from the first row, and the **high nibble** stores
+    the value from the second row.
+
+    Args:
+        arr: An int8 tensor containing int4 values in the range `[-8, 7]`.
+        axis: The axis along which to pack the tensor. Defaults to 0.
+
+    Returns:
+        tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is
+            the packed int8 tensor with int4 values stored in nibbles,
+            `packed_shape` is the shape of the packed tensor, and `orig_rows`
+            is the original (unpacked) row count prior to any padding that may
+            have been inserted when an odd number of rows is supplied.
+
+    Example:
+
+    ```python
+    >>> import numpy as np
+    >>> from keras.quantizers import pack_int4, unpack_int4
+
+    # Example with axis=0
+    # Original array has shape (3, 2)
+    >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8)
+
+    # Pack the array along axis 0. Since the length of axis 0 (3) is
+    # odd, it will be padded to a length of 4. The packed array will
+    # have a shape of (ceil(3/2), 2) = (2, 2).
+    >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0)
+    >>> print("Packed array:\n", packed)
+    Packed array:
+    [[  45 -121]
+     [   1    0]]
+
+    # Now, unpack the array back to its original form
+    >>> unpacked = unpack_int4(packed, orig_len, axis=0)
+    >>> print("Unpacked array:\n", unpacked)
+    Unpacked array:
+    [[-3  7]
+     [ 2 -8]
+     [ 1  0]]
+    >>> np.allclose(original_array, unpacked)
+    True
+
+    # Example with axis=1
+    # Original array has shape (2, 3)
+    >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8)
+
+    # Pack along axis 1. Length of axis 1 (3) is padded to 4.
+    # The new shape is (2, ceil(3/2)) = (2, 2).
+    >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1)
+    >>> print("Packed array:\n", packed)
+    Packed array:
+    [[125   2]
+     [ 24   0]]
+
+    # Unpack the array
+    >>> unpacked = unpack_int4(packed, orig_len, axis=1)
+    >>> print("Unpacked array:\n", unpacked)
+    Unpacked array:
+    [[-3  7  2]
+     [-8  1  0]]
+    >>> np.allclose(original_array, unpacked)
+    True
+    ```
+    """
+    if backend.standardize_dtype(arr.dtype) != "int8":
+        raise TypeError(
+            "Expected int8 tensor for packing, got {}".format(arr.dtype)
+        )
+
+    rank = getattr(arr.shape, "rank", None) or len(arr.shape)
+
+    if axis < 0:
+        axis += rank
+
+    # 1. Bring `axis` to the front.
+    perm = [axis] + [i for i in range(rank) if i != axis]
+    inv_perm = [perm.index(i) for i in range(rank)]
+    transposed = ops.transpose(arr, perm)
+
+    # 2. Pad to even length.
+    rows = ops.shape(transposed)[0]
+    needs_pad = ops.equal(ops.mod(rows, 2), 1)
+
+    # Always append one zero row so the tensor shape is static for JAX. If no
+    # padding is actually needed, we'll slice it away later.
+    zero_row = transposed[:1, ...] * 0  # same dtype/shape (1, ...)
+    padded_full = ops.concatenate([transposed, zero_row], axis=0)
+
+    # Number of valid rows after (possible) padding:
+    # rows + (1 if needs_pad else 0)
+    rows_packed = rows + ops.cast(needs_pad, "int32")
+
+    # Slice to keep only the valid rows. This keeps the shape rank static while
+    # allowing the row count to be dynamic.
+    padded = padded_full[:rows_packed, ...]
+
+    # 3-4. Group in pairs and pack.
+    low = padded[::2, ...]
+    high = padded[1::2, ...]
+
+    mask = ops.array(0x0F, dtype="int8")
+    low_u = ops.bitwise_and(low, mask)
+    high_u = ops.bitwise_and(high, mask)
+
+    packed = ops.bitwise_or(low_u, ops.left_shift(high_u, 4))
+    packed = ops.cast(packed, "int8")
+
+    # 5-6. Restore shape.
+    packed = ops.transpose(packed, inv_perm)  # back to original order
+    orig_len = rows  # number of slices before padding
+    return packed, ops.shape(packed), orig_len
+
+
+@keras_export("keras.quantizers.unpack_int4")
+def unpack_int4(packed, orig_len, axis=0):
+    """Unpack a packed int4 tensor back to an int8 tensor in the range [-8, 7].
+
+    This function reverses the packing performed by `pack_int4`, restoring
+    the original int8 tensor (values in the range [-8, 7]) from a packed int8
+    tensor where each element contains two int4 values (one in the lower
+    nibble, one in the upper nibble).
+
+    The function restores the original axis order and removes any
+    padding that was added during packing.
+
+    Args:
+        packed: An int8 tensor containing packed int4 values along the
+            specified axis. Each int8 value encodes two int4 values.
+        orig_len: The original (unpadded) length of the axis that was
+            packed. This is used to remove any padding that may have
+            been added during packing to ensure an even number of rows.
+        axis: The axis along which the tensor was packed. Defaults to 0.
+
+    Returns:
+        unpacked: An int8 tensor with the same shape as the original
+            (unpacked) tensor, with values in the range [-8, 7].
+
+    Example:
+
+    ```python
+    >>> import numpy as np
+    >>> from keras.quantizers import pack_int4, unpack_int4
+
+    # Example with axis=0
+    # Original array has shape (3, 2)
+    >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8)
+
+    # Pack the array along axis 0. Since the length of axis 0 (3) is
+    # odd, it will be padded to a length of 4. The packed array will
+    # have a shape of (ceil(3/2), 2) = (2, 2).
+    >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0)
+    >>> print("Packed array:\n", packed)
+    Packed array:
+    [[  45 -121]
+     [   1    0]]
+
+    # Now, unpack the array back to its original form
+    >>> unpacked = unpack_int4(packed, orig_len, axis=0)
+    >>> print("Unpacked array:\n", unpacked)
+    Unpacked array:
+    [[-3  7]
+     [ 2 -8]
+     [ 1  0]]
+    >>> np.allclose(original_array, unpacked)
+    True
+
+    # Example with axis=1
+    # Original array has shape (2, 3)
+    >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8)
+
+    # Pack along axis 1. Length of axis 1 (3) is padded to 4.
+    # The new shape is (2, ceil(3/2)) = (2, 2).
+    >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1)
+    >>> print("Packed array:\n", packed)
+    Packed array:
+    [[125   2]
+     [ 24   0]]
+
+    # Unpack the array
+    >>> unpacked = unpack_int4(packed, orig_len, axis=1)
+    >>> print("Unpacked array:\n", unpacked)
+    Unpacked array:
+    [[-3  7  2]
+     [-8  1  0]]
+    >>> np.allclose(original_array, unpacked)
+    True
+    ```
+    """
+    if backend.standardize_dtype(packed.dtype) != "int8":
+        raise TypeError(
+            f"Expected int8 tensor for unpacking, got {packed.dtype}"
+        )
+
+    rank = getattr(packed.shape, "rank", None) or len(packed.shape)
+
+    if axis < 0:
+        axis += rank
+
+    # Fast path for the most common case in Dense layers
+    if axis == 0 and rank == 2:
+        # The result of the bitwise op is a wider dtype (e.g., int32).
+        mask = ops.array(0x0F, dtype=packed.dtype)
+        low_unpacked = ops.bitwise_and(packed, mask)
+        high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask)
+
+        # Convert values from [0, 15] to [-8, 7].
+        low_signed = ops.where(
+            low_unpacked < 8, low_unpacked, low_unpacked - 16
+        )
+        high_signed = ops.where(
+            high_unpacked < 8, high_unpacked, high_unpacked - 16
+        )
+
+        # Interleave and reshape
+        stacked = ops.stack([low_signed, high_signed], axis=1)
+        unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:]))
+
+        # Remove padding and return
+        return unpacked[:orig_len, ...]
+
+    # General case
+    perm = [axis] + [i for i in range(rank) if i != axis]
+    inv_perm = [perm.index(i) for i in range(rank)]
+    transposed = ops.transpose(packed, perm)
+
+    # 1. Split nibbles.
+    mask = ops.array(0x0F, dtype="int8")  # int8 arrays
+    low = ops.bitwise_and(transposed, mask)
+    high = ops.bitwise_and(ops.right_shift(transposed, 4), mask)
+
+    eight = ops.array(8, dtype="int8")
+    sixteen = ops.array(16, dtype="int8")
+
+    def to_signed(x):
+        return ops.where(x < eight, x, x - sixteen)
+
+    low = to_signed(low)
+    high = to_signed(high)
+
+    # 2. Interleave and reshape.
+    stacked = ops.stack([low, high], axis=1)  # (pairs, 2, ...)
+    unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))
+
+    # 4. Remove padding and restore original layout.
+    unpacked = unpacked[:orig_len, ...]
+    unpacked = ops.transpose(unpacked, inv_perm)
+
+    return unpacked  # dtype is int8
diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py
index 1fc7c94df7d6..60ec20a85606 100644
--- a/keras/src/quantizers/quantizers_test.py
+++ b/keras/src/quantizers/quantizers_test.py
@@ -107,6 +107,60 @@ def test_quantize_and_dequantize(self):
         # A loose assertion due to an expected quantization error
         self.assertAllClose(qdq_values, values, atol=5e-1)
 
+    @parameterized.named_parameters(
+        ("even_rows", (4, 5), 0),
+        ("odd_rows", (5, 5), 0),
+        ("even_rows_axis_0_negative", (4, 5), -1),
+        ("odd_rows_axis_0_negative", (5, 5), -1),
+        ("even_rows_axis_1", (4, 6), 1),
+        ("odd_rows_axis_1", (4, 7), 1),
+        ("3d_even_rows_axis_0", (4, 5, 3), 0),
+        ("3d_odd_rows_axis_0", (5, 5, 3), 0),
+        ("3d_even_rows_axis_1", (4, 6, 3), 1),
+        ("3d_odd_rows_axis_1", (4, 7, 3), 1),
+        ("3d_even_rows_axis_2", (4, 5, 6), 2),
+        ("3d_odd_rows_axis_2", (4, 5, 7), 2),
+        ("4d_odd_rows_axis_0", (2, 3, 5, 4), 0),
+        ("4d_odd_rows_axis_1", (2, 3, 5, 4), 1),
+        ("4d_odd_rows_axis_2", (2, 3, 5, 4), 2),
+        ("4d_odd_rows_axis_3", (2, 3, 5, 4), 3),
+        ("4d_even_rows_axis_0", (2, 4, 5, 4), 0),
+        ("4d_even_rows_axis_1", (2, 4, 5, 4), 1),
+        ("4d_even_rows_axis_2", (2, 4, 5, 4), 2),
+        ("4d_even_rows_axis_3", (2, 4, 5, 4), 3),
+        ("4d_even_rows_axis_0_negative", (2, 4, 5, 4), -1),
+        ("4d_even_rows_axis_1_negative", (2, 4, 5, 4), -2),
+        ("4d_even_rows_axis_2_negative", (2, 4, 5, 4), -3),
+        ("4d_even_rows_axis_3_negative", (2, 4, 5, 4), -4),
+    )
+    def test_pack_unpack_int4(self, shape, axis):
+        # Create a random tensor with int4 values [-8, 7]
+        arr = ops.cast(
+            ops.floor(random.uniform(shape, minval=-8, maxval=8)), "int8"
+        )
+
+        # Pack the tensor
+        packed, packed_shape, orig_len = quantizers.pack_int4(arr, axis=axis)
+
+        # Unpack the tensor
+        unpacked = quantizers.unpack_int4(packed, orig_len, axis=axis)
+
+        # Verify that the packed tensor is int8
+        self.assertDType(packed, "int8")
+
+        # Verify that the unpacked tensor is int8
+        self.assertDType(unpacked, "int8")
+
+        # The unpacked tensor should be the same as the original tensor
+        self.assertAllClose(unpacked, arr)
+
+        # Test the packed shape
+        expected_packed_shape = list(shape)
+        expected_packed_shape[axis] = (expected_packed_shape[axis] + 1) // 2
+        self.assertEqual(
+            list(ops.convert_to_numpy(packed_shape)), expected_packed_shape
+        )
+
     @parameterized.named_parameters(
         ("per_tensor", None),
         ("per_channel", -1),
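
A minimal end-to-end sketch of the int4 path introduced above, assuming only the APIs added in this diff (`Dense.quantize("int4")` plus the public `keras.quantizers.pack_int4` / `unpack_int4` exports); the error bound used in the assertion is illustrative, not a guaranteed tolerance:

import numpy as np
from keras import layers, ops, quantizers

# Quantize a built Dense layer to int4. The kernel is stored packed, two
# int4 nibbles per int8 byte, and unpacked on the fly in the forward pass.
layer = layers.Dense(units=16)
layer.build((None, 8))
x = np.random.random((2, 8)).astype("float32")
y_float = layer(x)
layer.quantize("int4")
y_int4 = layer(x)
assert float(ops.mean(ops.square(y_float - y_int4))) < 1e-2  # illustrative bound

# Round-trip the packing helpers directly on values in the int4 range [-8, 7].
values = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8)
packed, packed_shape, orig_len = quantizers.pack_int4(values, axis=0)
restored = quantizers.unpack_int4(packed, orig_len, axis=0)
assert np.array_equal(ops.convert_to_numpy(restored), values)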