
Commit eff0d82

New multi-step QAT API
**Summary:** This commit adds a new multi-step QAT API with the main goal of simplifying the existing UX. The new API uses the same `QATConfig` for both the prepare and convert steps, and automatically infers the fake quantization configs from a PTQ base config provided by the user:

```python
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)

# train (not shown)

# convert
quantize_(m, QATConfig(base_config, step="convert"))
```

The main improvements include:
- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of two)
- No chance of incompatible prepare vs. convert configs
- Much less boilerplate code for the most common use case
- Simpler config names

For less common use cases such as experimentation, users can still specify arbitrary fake quantization configs for activations and/or weights as before. This remains important since there may not always be a corresponding PTQ base config. For example:

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

# train and convert same as above (not shown)
```

**BC-breaking notes:** This change by itself is technically not BC-breaking since we keep the old path around, but it will become BC-breaking when we deprecate and remove the old path in the future.

Before:

```python
# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)

# train (not shown)

# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

After: (see above)

**Test Plan:**

```
python test/quantization/test_qat.py
```

ghstack-source-id: 7adbc7c
Pull Request resolved: #2629
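For orientation, here is a minimal end-to-end sketch of where the two steps sit relative to a training loop. The toy model, optimizer, data, and loop below are illustrative placeholders and not part of this commit; only the `QATConfig` usage reflects the new API.

```python
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

# Placeholder model; any module with linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))

# prepare: insert fake quantization ops into the model
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# train with fake quantization in the loop (placeholder loop and data)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
for _ in range(10):
    x, target = torch.randn(8, 64), torch.randn(8, 16)
    loss = loss_fn(model(x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# convert: swap in real quantized ops using the same base config
quantize_(model, QATConfig(base_config, step="convert"))
```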
1 parent a4e0235 commit eff0d82

9 files changed, +480 -142 lines changed


README.md

Lines changed: 11 additions & 6 deletions
````diff
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup*
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):
 
 ```python
-from torchao.quantization import quantize_
-from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-quantize_(my_model, qat_config)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig
+
+# prepare
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(my_model, QATConfig(base_config, step="prepare"))
+
+# train model (not shown)
+
+# convert
+quantize_(my_model, QATConfig(base_config, step="convert"))
 ```
 
 Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).
````

docs/source/api_ref_qat.rst

Lines changed: 7 additions & 4 deletions
```diff
@@ -6,7 +6,7 @@ torchao.quantization.qat
 
 .. currentmodule:: torchao.quantization.qat
 
-QAT Configs for quantize_
+Main Config for quantize_
 ---------------------------------------
 For a full example of how to use QAT with our main `quantize_` API,
 please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
@@ -15,29 +15,32 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
     :toctree: generated/
     :nosignatures:
 
-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
+    QATConfig
+    QATConfigStep
 
 Custom QAT APIs
 ---------------
 .. autosummary::
     :toctree: generated/
     :nosignatures:
 
+    FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
     FakeQuantizedLinear
    FakeQuantizedEmbedding
    FakeQuantizer
    linear.enable_linear_fake_quant
    linear.disable_linear_fake_quant
 
-Legacy QAT Quantizers
+Legacy QAT APIs
 ---------------------
 
 .. autosummary::
     :toctree: generated/
     :nosignatures:
 
+    IntXQuantizationAwareTrainingConfig
+    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer
```
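The custom configs listed above can still be combined with `QATConfig` when no matching PTQ base config exists, for example to fake quantize embedding weights only. Below is a minimal sketch based on the usage exercised in the test changes later in this commit; the toy model is illustrative and not part of the commit.

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# Toy model with both an embedding and a linear layer (illustrative only).
model = torch.nn.Sequential(
    torch.nn.Embedding(128, 64),
    torch.nn.Linear(64, 64),
)

# Linear layers: fake quantize both activations and weights.
act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, QATConfig(activation_config=act_config, weight_config=weight_config))

# Embeddings: weight-only fake quantization, selected via filter_fn.
quantize_(
    model,
    QATConfig(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```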

docs/source/finetuning.rst

Lines changed: 11 additions & 24 deletions
```diff
@@ -205,21 +205,14 @@ because we are not actually casting the fake quantized values.
 
 .. code:: py
 
-    from torchao.quantization import (
-        quantize_,
-    )
-    from torchao.quantization.qat import (
-        FakeQuantizeConfig,
-        IntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+    from torchao.quantization.qat import QATConfig
+
     model = get_model()
 
-    # prepare: insert fake quantization ops
-    # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
-    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
-    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
-    quantize_(model, qat_config)
+    # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
+    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+    quantize_(model, QATConfig(base_config, step="prepare"))
 
     # fine-tune
     train_loop(model)
@@ -232,18 +225,12 @@ The next step is to actually quantize the model:
 
 .. code:: py
 
-    from torchao.quantization import (
-        Int8DynamicActivationInt4WeightConfig,
-    )
-    from torchao.quantization.qat import (
-        FromIntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import Int8DynamicActivationInt4WeightConfig
 
-    # convert: transform fake quantization ops into actual quantized ops
-    # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
-    # quantized activation and weight tensor subclasses
-    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+    # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
+    quantize_(model, QATConfig(base_config, step="convert"))
+
+    # inference or generate
 
 Now our model is ready for serving, and will typically have higher quantized
 accuracy than if we did not apply the prepare step (fake quantization) during
```
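For readers unfamiliar with the term: during the prepare step the model only simulates quantization numerics in floating point ("fake quantization") rather than casting values to integer dtypes, which is why the doc above notes that nothing is actually cast. The hypothetical helper below is a conceptual sketch of symmetric per-group int4 fake quantization, not torchao's actual implementation; real QAT additionally uses a straight-through estimator so gradients flow through the rounding.

```python
import torch

def fake_quantize_int4_per_group(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Conceptual sketch: quantize then immediately dequantize, staying in float."""
    orig_shape = w.shape
    w = w.reshape(-1, group_size)
    # symmetric int4 range is [-8, 7]; one scale per group of 32 values
    scale = (w.abs().amax(dim=-1, keepdim=True) / 7).clamp(min=1e-12)
    q = torch.clamp(torch.round(w / scale), -8, 7)
    return (q * scale).reshape(orig_shape)

# Output keeps the input's dtype and shape, so the rest of the model
# trains on values that mimic int4 quantization error.
w = torch.randn(64, 64)
w_fq = fake_quantize_int4_per_group(w)
```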

test/quantization/test_qat.py

Lines changed: 133 additions & 48 deletions
```diff
@@ -34,6 +34,7 @@
     ComposableQATQuantizer,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
+    QATConfig,
     initialize_fake_quantizers,
 )
 from torchao.quantization.qat.embedding import (
@@ -59,7 +60,7 @@
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -1261,11 +1262,65 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_standalone(self):
+    def test_qat_config_init(self):
+        """
+        Test that the correct errors are thrown if `QATConfig` is not instantiated properly.
+        """
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
+
+        # OK
+        QATConfig(base_config, step="prepare")
+        QATConfig(base_config, step="convert")
+        QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
+        QATConfig(weight_config=fq_config, step="prepare")
+
+        # OK: good step values
+        self.assertEqual(QATConfig(base_config).step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="Prepare").step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert")
+
+        # Bad step
+        with self.assertRaisesRegex(ValueError, "`step` must be one of"):
+            QATConfig(base_config, step="blah")
+
+        # Step was not a keyword arg
+        with self.assertRaisesRegex(
+            TypeError, "4 positional arguments but 5 were given"
+        ):
+            QATConfig(base_config, None, None, "prepare")
+
+        # No configs are provided
+        with self.assertRaisesRegex(
+            ValueError, "One of `base_config` or `weight_config` must be specified"
+        ):
+            QATConfig(step="prepare")
+
+        # Clashing configs are provided
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, weight_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, activation_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(
+            ValueError, "must be specified in the convert step"
+        ):
+            QATConfig(weight_config=fq_config, step="convert")
+
+        # FakeQuantizeConfigBase was specified as base_config
+        with self.assertRaisesRegex(
+            ValueError,
+            "was passed as `base_config`. Did you mean to do the following instead?",
+        ):
+            QATConfig(fq_config, step="prepare")
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_prepare(self):
         """
         Test that the following:
 
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, QATConfig(...))
 
         can produce the same results as `ComposableQATQuantizer`.
         """
```
```diff
@@ -1290,20 +1345,15 @@ def test_quantize_api_standalone(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)
 
         # quantize_ API
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config1 = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        qat_config2 = QATConfig(weight_config=weight_config)
+        quantize_(m, qat_config1)
         quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
-            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+            m, qat_config2, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)
         )
 
         # Compare model values
@@ -1322,37 +1372,29 @@ def test_quantize_api_errors(self):
         Test that we throw exceptions with helpful error messages if `quantize_`
         runs into unexpected configurations.
         """
-        my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        qat_config = QATConfig(activation_config=fq_config, weight_config=fq_config)
         m = M3()
 
         # Embedding currently only supports weight-only quantization
         with self.assertRaisesRegex(
             ValueError, "Activation fake quantization is not supported for embedding"
         ):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.Embedding),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.Embedding))
 
         # Only linear and embedding are supported currently
         with self.assertRaisesRegex(ValueError, "does not have QAT support"):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.ReLU),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU))
 
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_convert_path(self):
+    def test_quantize_api_e2e(self):
         """
         Test that the following:
 
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, int8_dynamic_activation_int4_weight())
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
 
         can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
         """
```
```diff
@@ -1370,16 +1412,8 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)
 
         # quantize_ prepare
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-        )
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
 
         # Compare prepared values
         torch.manual_seed(self.SEED)
@@ -1393,8 +1427,7 @@
         baseline_model = baseline_quantizer.convert(baseline_model)
 
         # quantize_ convert
-        quantize_(m, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, QATConfig(base_config, step="convert"))
 
         # Compare converted values
         torch.manual_seed(self.SEED)
```
```diff
@@ -1447,14 +1480,12 @@ def test_qat_linear_bias(self):
         Test that QAT supports linear bias.
         """
         m = ModelWithLinearBias()
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        quantize_(m, qat_config)
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
@@ -1653,7 +1684,7 @@ def test_qat_range_learning(self):
         )
         m = M()
         example_inputs = m.example_inputs()
-        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+        quantize_(m, QATConfig(weight_config=config))
 
         # Not initialized, should fail
         for t in m._get_all_weight_qparams():
```
```diff
@@ -1756,6 +1787,60 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_legacy_quantize_api_e2e(self):
+        """
+        Test that the following two APIs are numerically equivalent:
+
+        New API:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+
+        Old API:
+            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+            quantize_(model, Int8DynamicActivationInt4WeightConfig())
+        """
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline prepare
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
+        quantize_(baseline_model, old_qat_config)
+
+        # QATConfig prepare
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+        # Baseline convert
+        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(baseline_model, base_config)
+
+        # quantize_ convert
+        quantize_(m, QATConfig(base_config, step="convert"))
+
+        # Compare converted values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
```

0 commit comments
