**Summary:** This commit adds a new multi-step QAT API with the
main goal of simplifying the existing UX. The new API uses the
same `QATConfig` for both the prepare and convert steps, and
automatically infers the fake quantization configs based on
a PTQ base config provided by the user:
```python
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)
# train (not shown)
# convert
quantize_(m, QATConfig(base_config, step="convert"))
```
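The `# train (not shown)` step above is ordinary fine-tuning and is not part of the QAT API. As a rough sketch only, assuming a hypothetical prepared model `m` and dataloader `train_loader` (both illustrative, not defined by this PR), the loop might look like:
```python
import torch
import torch.nn.functional as F

# Illustrative training loop for the "train (not shown)" step.
# `m` is the prepared model and `train_loader` yields (inputs, labels);
# both are assumptions for this sketch, not part of the QAT API.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-5)
m.train()
for inputs, labels in train_loader:
    optimizer.zero_grad()
    logits = m(inputs)
    loss = F.cross_entropy(logits, labels)
    loss.backward()
    optimizer.step()
```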
The main improvements include:
- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of two)
- No chance of incompatible prepare vs. convert configs
- Much less boilerplate code for the most common use case
- Simpler config names
For less common use cases such as experimentation, users can
still specify arbitrary fake quantization configs for
activations and/or weights as before. This is still important
since there may not always be a corresponding PTQ base config.
For example:
```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)
# train and convert same as above (not shown)
```
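Since the activation and weight configs are specified independently ("activations and/or weights"), a weight-only setup can presumably be expressed by omitting the activation config. This is an unverified sketch, assuming `QATConfig` accepts a `weight_config` without an `activation_config`:
```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# Weight-only fake quantization: no activation config is passed.
# Assumes QATConfig allows activation_config to be omitted (not verified here).
weight_only_config = QATConfig(
    weight_config=IntxFakeQuantizeConfig(torch.int4, group_size=32),
    step="prepare",
)
quantize_(model, weight_only_config)
```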
**BC-breaking notes:** This change by itself is technically not
BC-breaking since the old path is kept around, but it will become
BC-breaking once the old path is deprecated and removed in the future.
Before:
```python
import torch

from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import (
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
    IntxFakeQuantizeConfig,
)

# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)
# train (not shown)
# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```
After: (see above)
**Test Plan:**
```
python test/quantization/test_qat.py
```
ghstack-source-id: 7adbc7c
Pull Request resolved: #2629
**Files changed:** README.md (11 additions, 6 deletions)

The QAT section of the README (around line 179) is updated for the new API; the example's old imports of `IntxFakeQuantizeConfig` and `IntXQuantizationAwareTrainingConfig` are removed. Relevant README context:

> Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):

> Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).