New multi-step QAT API #2629

Merged: 14 commits, Aug 4, 2025
17 changes: 11 additions & 6 deletions README.md
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup**
Post-training quantization (PTQ) can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3. For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
quantize_(my_model, qat_config)
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(my_model, QATConfig(base_config, step="prepare"))

# train model (not shown)

# convert
quantize_(my_model, QATConfig(base_config, step="convert"))
```

Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).
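
For illustration, one way to combine the two is to fake-quantize only the frozen base linears while the LoRA adapters stay in high precision, using `quantize_`'s `filter_fn`. The sketch below is only an assumption about how this could be wired with the new `QATConfig`, not the linked recipe itself: the `LoRALinear` wrapper, layer sizes, and name-based filter are all hypothetical.

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# Hypothetical LoRA wrapper: a frozen base linear plus trainable low-rank adapters
class LoRALinear(torch.nn.Module):
    def __init__(self, base: torch.nn.Linear, rank: int = 8):
        super().__init__()
        self.base = base.requires_grad_(False)
        self.lora_a = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, base.out_features, bias=False)

    def forward(self, x):
        return self.base(x) + self.lora_b(self.lora_a(x))

model = torch.nn.Sequential(LoRALinear(torch.nn.Linear(256, 256)))

def only_base(module, fqn):
    # target the frozen base linears; skip the LoRA adapter linears
    return isinstance(module, torch.nn.Linear) and "lora" not in fqn

# prepare: fake-quantize only the base weights
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"), filter_fn=only_base)

# fine-tune the adapters (not shown), then convert the base weights as usual
quantize_(model, QATConfig(base_config, step="convert"), filter_fn=only_base)
```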
11 changes: 7 additions & 4 deletions docs/source/api_ref_qat.rst
@@ -6,7 +6,7 @@ torchao.quantization.qat

.. currentmodule:: torchao.quantization.qat

QAT Configs for quantize_
Main Config for quantize_
---------------------------------------
For a full example of how to use QAT with our main `quantize_` API,
please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
@@ -15,29 +15,32 @@
:toctree: generated/
:nosignatures:

IntXQuantizationAwareTrainingConfig
FromIntXQuantizationAwareTrainingConfig
QATConfig
QATStep
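
A minimal sketch of the two-step flow with these types (mirroring the README example in this PR; `model` and the `group_size` value are placeholders, and the `QATStep` values are interchangeable with the string steps ``"prepare"`` / ``"convert"``):

.. code:: py

    from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
    from torchao.quantization.qat import QATConfig, QATStep

    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

    # prepare: insert fake quantization ops
    quantize_(model, QATConfig(base_config, step=QATStep.PREPARE))

    # ... fine-tune ...

    # convert: swap in real quantized ops using the same base config
    quantize_(model, QATConfig(base_config, step=QATStep.CONVERT))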

Custom QAT APIs
---------------
.. autosummary::
:toctree: generated/
:nosignatures:

FakeQuantizeConfigBase
IntxFakeQuantizeConfig
FakeQuantizedLinear
FakeQuantizedEmbedding
FakeQuantizer
linear.enable_linear_fake_quant
linear.disable_linear_fake_quant

Legacy QAT Quantizers
Legacy QAT APIs
---------------------

.. autosummary::
:toctree: generated/
:nosignatures:

IntXQuantizationAwareTrainingConfig
FromIntXQuantizationAwareTrainingConfig
Int4WeightOnlyQATQuantizer
linear.Int4WeightOnlyQATLinear
Int8DynActInt4WeightQATQuantizer
35 changes: 11 additions & 24 deletions docs/source/finetuning.rst
@@ -205,21 +205,14 @@ because we are not actually casting the fake quantized values.

.. code:: py

from torchao.quantization import (
quantize_,
)
from torchao.quantization.qat import (
FakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

model = get_model()

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)
# prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# fine-tune
train_loop(model)
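
For intuition, the fake quantization inserted by the prepare step can be pictured as a quantize-dequantize round trip: weights are snapped to a low-precision grid but remain in their original floating-point dtype, which is why no casting happens. The following is an illustrative sketch of that math only (a symmetric per-group int4 scheme is assumed here), not the torchao implementation:

.. code:: py

    import torch

    def fake_quantize_symmetric(w: torch.Tensor, bits: int = 4, group_size: int = 32):
        # group the weights and compute one scale per group
        orig_shape = w.shape
        w = w.reshape(-1, group_size)
        qmax = 2 ** (bits - 1) - 1  # 7 for int4
        scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
        # quantize, clamp to the int4 range, then immediately dequantize:
        # the result keeps the input's dtype and shape
        q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
        return (q * scale).reshape(orig_shape)

    w = torch.randn(64, 64)
    w_fq = fake_quantize_symmetric(w)  # still float, but discretized

During `train_loop(model)` above, gradients flow through the rounding via a straight-through estimator, which is what lets fine-tuning adapt the weights to the quantization grid.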
@@ -232,18 +225,12 @@ The next step is to actually quantize the model:

.. code:: py

from torchao.quantization import (
Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
FromIntXQuantizationAwareTrainingConfig,
)
from torchao.quantization import Int8DynamicActivationInt4WeightConfig

# convert: transform fake quantization ops into actual quantized ops
# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
# convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
quantize_(model, QATConfig(base_config, step="convert"))

# inference or generate

Now our model is ready for serving, and will typically have higher quantized
accuracy than if we did not apply the prepare step (fake quantization) during fine-tuning.
184 changes: 136 additions & 48 deletions test/quantization/test_qat.py
@@ -34,6 +34,8 @@
ComposableQATQuantizer,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
QATConfig,
QATStep,
initialize_fake_quantizers,
)
from torchao.quantization.qat.embedding import (
@@ -59,7 +61,7 @@
_get_qmin_qmax,
)
from torchao.quantization.quant_api import (
int8_dynamic_activation_int4_weight,
Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.quant_primitives import (
MappingType,
@@ -1261,11 +1263,67 @@ def test_qat_prototype_bc(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api_standalone(self):
def test_qat_config_init(self):
"""
Test that the correct errors are thrown if `QATConfig` is not instantiated properly.
"""
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")

# OK
QATConfig(base_config, step="prepare")
QATConfig(base_config, step="convert")
QATConfig(base_config, step=QATStep.PREPARE)
QATConfig(base_config, step=QATStep.CONVERT)
QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
QATConfig(weight_config=fq_config, step="prepare")

# OK: good step values
self.assertEqual(QATConfig(base_config).step, "prepare")
self.assertEqual(QATConfig(base_config, step="Prepare").step, "prepare")
self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert")

# Bad step
with self.assertRaisesRegex(ValueError, "`step` must be one of"):
QATConfig(base_config, step="blah")

# Step was not a keyword arg
with self.assertRaisesRegex(
TypeError, "4 positional arguments but 5 were given"
):
QATConfig(base_config, None, None, "prepare")

# No configs are provided
with self.assertRaisesRegex(
ValueError, "One of `base_config` or `weight_config` must be specified"
):
QATConfig(step="prepare")

# Clashing configs are provided
with self.assertRaisesRegex(ValueError, "Cannot specify both"):
QATConfig(base_config, weight_config=fq_config, step="prepare")
with self.assertRaisesRegex(ValueError, "Cannot specify both"):
QATConfig(base_config, activation_config=fq_config, step="prepare")
with self.assertRaisesRegex(
ValueError, "must be specified in the convert step"
):
QATConfig(weight_config=fq_config, step="convert")

# FakeQuantizeConfigBase was specified as base_config
with self.assertRaisesRegex(
ValueError,
"was passed as `base_config`. Did you mean to do the following instead?",
):
QATConfig(fq_config, step="prepare")

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api_prepare(self):
"""
Test that the following:

quantize_(model, IntXQuantizationAwareTrainingConfig(...))
quantize_(model, QATConfig(...))

can produce the same results as `ComposableQATQuantizer`.
"""
@@ -1290,20 +1348,15 @@ def test_quantize_api_standalone(self):
baseline_model = baseline_quantizer.prepare(baseline_model)

# quantize_ API
activation_config = IntxFakeQuantizeConfig(
torch.int8,
"per_token",
is_symmetric=False,
)
act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
quantize_(
m,
IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
qat_config1 = QATConfig(
activation_config=act_config, weight_config=weight_config
)
qat_config2 = QATConfig(weight_config=weight_config)
quantize_(m, qat_config1)
quantize_(
m,
IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
m, qat_config2, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)
)

# Compare model values
@@ -1322,37 +1375,29 @@ def test_quantize_api_errors(self):
Test that we throw exceptions with helpful error messages if `quantize_`
runs into unexpected configurations.
"""
my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
fq_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
qat_config = QATConfig(activation_config=fq_config, weight_config=fq_config)
m = M3()

# Embedding currently only supports weight-only quantization
with self.assertRaisesRegex(
ValueError, "Activation fake quantization is not supported for embedding"
):
quantize_(
m,
IntXQuantizationAwareTrainingConfig(my_config, my_config),
lambda m, _: isinstance(m, torch.nn.Embedding),
)
quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.Embedding))

# Only linear and embedding are supported currently
with self.assertRaisesRegex(ValueError, "does not have QAT support"):
quantize_(
m,
IntXQuantizationAwareTrainingConfig(my_config, my_config),
lambda m, _: isinstance(m, torch.nn.ReLU),
)
quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU))

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api_convert_path(self):
def test_quantize_api_e2e(self):
"""
Test that the following:

quantize_(model, IntXQuantizationAwareTrainingConfig(...))
quantize_(model, FromIntXQuantizationAwareTrainingConfig(...))
quantize_(model, int8_dynamic_activation_int4_weight())
quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))

can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
"""
@@ -1370,16 +1415,8 @@ def test_quantize_api_convert_path(self):
baseline_model = baseline_quantizer.prepare(baseline_model)

# quantize_ prepare
activation_config = IntxFakeQuantizeConfig(
torch.int8,
"per_token",
is_symmetric=False,
)
weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
quantize_(
m,
IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)
base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
quantize_(m, QATConfig(base_config, step="prepare"))

# Compare prepared values
torch.manual_seed(self.SEED)
Expand All @@ -1393,8 +1430,7 @@ def test_quantize_api_convert_path(self):
baseline_model = baseline_quantizer.convert(baseline_model)

# quantize_ convert
quantize_(m, FromIntXQuantizationAwareTrainingConfig())
quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
quantize_(m, QATConfig(base_config, step="convert"))

# Compare converted values
torch.manual_seed(self.SEED)
@@ -1447,14 +1483,12 @@ def test_qat_linear_bias(self):
Test that QAT supports linear bias.
"""
m = ModelWithLinearBias()
activation_config = IntxFakeQuantizeConfig(
torch.int8, "per_token", is_symmetric=False
)
act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32)
quantize_(
m,
IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
qat_config = QATConfig(
activation_config=act_config, weight_config=weight_config
)
quantize_(m, qat_config)
example_inputs = m.example_inputs()
m(*example_inputs)

@@ -1653,7 +1687,7 @@ def test_qat_range_learning(self):
)
m = M()
example_inputs = m.example_inputs()
quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
quantize_(m, QATConfig(weight_config=config))

# Not initialized, should fail
for t in m._get_all_weight_qparams():
@@ -1756,6 +1790,60 @@ def test_qat_fp8a4w_quantizer(self):
self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
self.assertFalse(torch.equal(new_weight, prev_weight))

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_legacy_quantize_api_e2e(self):
"""
Test that the following two APIs are numerically equivalent:

New API:
quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))

Old API:
quantize_(model, IntXQuantizationAwareTrainingConfig(...))
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig())
"""
group_size = 16
torch.manual_seed(self.SEED)
m = M()
baseline_model = copy.deepcopy(m)

# Baseline prepare
act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
quantize_(baseline_model, old_qat_config)

# QATConfig prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
quantize_(m, QATConfig(base_config, step="prepare"))

# Compare prepared values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
out = m(*x)
baseline_out = baseline_model(*x2)
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

# Baseline convert
quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
quantize_(baseline_model, base_config)

# quantize_ convert
quantize_(m, QATConfig(base_config, step="convert"))

# Compare converted values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
out = m(*x)
baseline_out = baseline_model(*x2)
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()