Add int4 Quantization Support #21435
Conversation
Codecov Report: coverage diff

| | master | #21435 | +/- |
|---|---:|---:|---:|
| Coverage | 74.94% | 82.78% | +7.83% |
| Files | 565 | 565 | |
| Lines | 55224 | 55404 | +180 |
| Branches | 8610 | 8635 | +25 |
| Hits | 41386 | 45864 | +4478 |
| Misses | 11880 | 7425 | -4455 |
| Partials | 1958 | 2115 | +157 |
Thanks for the PR! The code generally looks good to me. What is the performance profile? How did you benchmark the change?
I hadn't yet benchmarked the code. I've now created two micro-benchmarks and linked them in the PR description. Please take a look!
Summary
This PR introduces support for `int4` weight-only quantization for the `Dense` layer. The implementation includes the necessary logic for packing and unpacking `int4` values, performing the quantized matrix multiplication, and ensuring compatibility with features like LoRA. The code currently implements the W4A8 quantization scheme.
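For context, here is a minimal usage sketch. It assumes the new mode plugs into the existing `layer.quantize(...)` API the same way the `int8` and `float8` modes do; the layer size, input shape, and data are purely illustrative.

```python
import numpy as np
import keras

# Illustrative only: assumes a Keras build that includes this PR, so that
# "int4" is accepted alongside the existing "int8"/"float8" modes.
layer = keras.layers.Dense(16)
layer.build((None, 8))

x = np.random.rand(2, 8).astype("float32")
y_float = layer(x)

layer.quantize("int4")  # packs the kernel to int4 and adds a kernel_scale variable
y_int4 = layer(x)

# The two outputs should agree up to quantization error.
err = np.max(np.abs(
    keras.ops.convert_to_numpy(y_float) - keras.ops.convert_to_numpy(y_int4)
))
print("max abs difference:", err)
```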
Description
The core changes include:

- Support for an `int4` quantization mode.
- **Packing and Unpacking Utilities** (see the NumPy sketch after this list):
  - `pack_int4` takes an `int8` tensor (representing `int4` values) and packs two 4-bit values into a single `int8` byte.
  - `unpack_int4` performs the reverse operation, unpacking the `int8` tensor back into an `int8` tensor of `int4` values.
- **`Dense` Layer Modifications:**
  - `_int4_build`: builds a packed `kernel` of `int8` dtype and a `kernel_scale` variable. The original input dimension is saved in `_orig_input_dim` to handle unpacking correctly.
  - `_int4_call`: defines the forward pass for the `int4`-quantized layer. It uses a `custom_gradient` to perform the matrix multiplication with the unpacked kernel and correctly computes the gradients with respect to the original inputs (see the forward-pass sketch after this list).
  - The `quantize` method now handles `mode="int4"`. It quantizes the float weights to `int4` values and then packs them using `pack_int4`.
  - The `enable_lora` method correctly determines the input dimension for the LoRA matrices when the layer is `int4`-quantized by using the saved `_orig_input_dim`.
  - The `_get_kernel_with_merged_lora` method handles the unpacking of the `int4` kernel before merging the LoRA weights, followed by re-quantization and re-packing.
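To make the packing scheme concrete, here is a small, self-contained NumPy sketch of the idea behind `pack_int4` / `unpack_int4`. The functions in the PR operate on Keras tensors and, judging by the tests, support packing along different axes; this sketch packs along the last axis only, and the signatures are illustrative rather than the PR's actual API.

```python
import numpy as np

def pack_int4(x):
    """Pack int4 values (stored in an int8 array, range [-8, 7]) two per byte
    along the last axis, which must have an even length."""
    bits = x.astype(np.uint8)              # reinterpret the two's-complement bit patterns
    low = bits[..., 0::2] & 0x0F           # even elements -> low nibble
    high = (bits[..., 1::2] & 0x0F) << 4   # odd elements  -> high nibble
    return (low | high).astype(np.int8)

def unpack_int4(packed):
    """Reverse of pack_int4: recover the sign-extended int4 values as int8."""
    bits = packed.astype(np.uint8)
    low = ((bits & 0x0F) << 4).astype(np.int8) >> 4  # arithmetic shift sign-extends
    high = packed.astype(np.int8) >> 4
    out = np.empty(packed.shape[:-1] + (packed.shape[-1] * 2,), dtype=np.int8)
    out[..., 0::2] = low
    out[..., 1::2] = high
    return out

values = np.array([[-8, 7, -1, 3]], dtype=np.int8)
packed = pack_int4(values)                 # shape (1, 2): half the bytes
assert np.array_equal(unpack_int4(packed), values)
```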
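And here is a minimal sketch of the forward-pass idea behind `_int4_call`, assuming the Keras convention that the dequantized kernel is the quantized kernel divided by `kernel_scale`. For brevity the kernel is taken as already unpacked to int8, whereas the PR unpacks it from the packed representation first; names and signatures are illustrative, not the PR's actual code.

```python
import keras
from keras import ops

def weight_only_quantized_matmul(inputs, kernel_int8, kernel_scale):
    """Matmul against a frozen, dequantized kernel, with a custom gradient
    so the backward pass only produces d(loss)/d(inputs)."""

    @ops.custom_gradient
    def _matmul(inputs):
        # Dequantize: cast the int8 weights to the compute dtype and rescale.
        float_kernel = ops.divide(ops.cast(kernel_int8, "float32"), kernel_scale)

        def grad_fn(*args, upstream=None):
            if upstream is None:
                (upstream,) = args
            # Gradient w.r.t. the inputs, computed with the same dequantized kernel.
            return ops.matmul(upstream, ops.transpose(float_kernel))

        return ops.matmul(inputs, float_kernel), grad_fn

    return _matmul(inputs)
```

In the actual layer, the kernel is first unpacked from its packed int4 form and `kernel_scale` is the scale saved when the weights were quantized.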
Testing
- Tests for `int4` quantization in `dense_test.py`. These tests cover basic correctness, serialization (saving/loading models), behavior with LoRA enabled, and various edge cases.
- Tests for the `pack_int4` and `unpack_int4` functions in `quantizers_test.py` to ensure they work correctly for various tensor shapes and axes.
Benchmarking
Note: Results collected with warmed-up GPUs and pre-loaded models and kernels.
Micro Benchmark with OPT 125M using KerasHub
[colab link]
Micro Benchmark with BERT Classifier using KerasHub
[colab link]
Limitation
The current implementation unpacks the kernel on every forward pass (the int4 kernel is stored in a packed int8 representation where each byte holds two nibbles). This means we lose some of the memory savings at runtime and incur a performance penalty.
We may be able to work around this in the future by writing custom kernels that operate directly on the packed int4 representation.
Further work