
Commit 058d99a

New multi-step QAT API
**Summary:** This commit adds a new multi-step QAT API with the main goal of simplifying the existing UX. The new API uses the same `QATConfig` for both the prepare and convert steps, and automatically infers the fake quantization configs based on a PTQ base config provided by the user:

```python
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

# prepare
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
qat_config = QATConfig(base_config, step="prepare")
quantize_(m, qat_config)

# train (not shown)

# convert
quantize_(m, QATConfig(base_config, step="convert"))
```

The main improvements include:

- A single config for both the prepare and convert steps
- A single `quantize_` call for convert (instead of 2)
- No chance of incompatible prepare vs. convert configs
- Much less boilerplate code for the most common use case
- Simpler config names

For less common use cases such as experimentation, users can still specify arbitrary fake quantization configs for activations and/or weights as before. This is still important since there may not always be a corresponding PTQ base config. For example:

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)

# train and convert same as above (not shown)
```

**BC-breaking notes:** This change by itself is technically not BC-breaking since we keep the old path around, but it will become BC-breaking when we deprecate and remove the old path in the future.

Before:

```python
# prepare
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)

# train (not shown)

# convert
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

After: (see above)

**Test Plan:**

```
python test/quantization/test_qat.py
```

ghstack-source-id: 7adbc7c
Pull Request resolved: #2629
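To see the new API in context, here is a minimal end-to-end sketch of the prepare → train → convert flow described above. The model, data, optimizer, and training loop are placeholders invented for illustration; only the `QATConfig` / `quantize_` calls come from this commit.

```python
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_
from torchao.quantization.qat import QATConfig

# Placeholder model and data; any module with nn.Linear layers would work similarly
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))
data = [(torch.randn(8, 64), torch.randint(0, 16, (8,))) for _ in range(10)]

# prepare: insert fake quantization ops inferred from the PTQ base config
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))

# train with fake quantization in the loop (placeholder loop)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
for x, y in data:
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()

# convert: swap in real quantized ops using the same base config
quantize_(model, QATConfig(base_config, step="convert"))
```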
1 parent a4e0235 commit 058d99a

9 files changed (+468, -141 lines)


README.md

Lines changed: 11 additions & 6 deletions

````diff
@@ -179,12 +179,17 @@ With this quantization flow, we achieve **67% VRAM reduction and 12-20% speedup*
 Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with [TorchTune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). For more details, please refer to the [QAT README](torchao/quantization/qat/README.md) and the [original blog](https://pytorch.org/blog/quantization-aware-training/):
 
 ```python
-from torchao.quantization import quantize_
-from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
-activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
-qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-quantize_(my_model, qat_config)
+from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+from torchao.quantization.qat import QATConfig
+
+# prepare
+base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+quantize_(my_model, QATConfig(base_config, step="prepare"))
+
+# train model (not shown)
+
+# convert
+quantize_(my_model, QATConfig(base_config, step="convert"))
 ```
 
 Users can also combine LoRA + QAT to speed up training by [1.89x](https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700) compared to vanilla QAT using this [fine-tuning recipe](https://github.com/pytorch/torchtune/blob/main/recipes/qat_lora_finetune_distributed.py).
````
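As a quick sanity check that the prepare step took effect (not part of the README change itself), one can count the swapped modules. This is a small sketch assuming `my_model` was prepared as in the snippet above and that `FakeQuantizedLinear` is importable from `torchao.quantization.qat`, as listed in the API reference below.

```python
# Sketch: count how many linears were swapped to FakeQuantizedLinear by the prepare step
from torchao.quantization.qat import FakeQuantizedLinear

num_fake_quant = sum(isinstance(m, FakeQuantizedLinear) for m in my_model.modules())
print(f"{num_fake_quant} linear layers are fake-quantized")
```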

docs/source/api_ref_qat.rst

Lines changed: 7 additions & 4 deletions

```diff
@@ -6,7 +6,7 @@ torchao.quantization.qat
 
 .. currentmodule:: torchao.quantization.qat
 
-QAT Configs for quantize_
+Main Config for quantize_
 ---------------------------------------
 For a full example of how to use QAT with our main `quantize_` API,
 please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended>`__.
@@ -15,29 +15,32 @@ please refer to the `QAT README <https://github.com/pytorch/ao/blob/main/torchao
     :toctree: generated/
     :nosignatures:
 
-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
+    QATConfig
+
 
 Custom QAT APIs
 ---------------
 .. autosummary::
     :toctree: generated/
     :nosignatures:
 
+    FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
     FakeQuantizedLinear
    FakeQuantizedEmbedding
    FakeQuantizer
    linear.enable_linear_fake_quant
    linear.disable_linear_fake_quant
 
-Legacy QAT Quantizers
+Legacy QAT APIs
 ---------------------
 
 .. autosummary::
     :toctree: generated/
     :nosignatures:
 
+    IntXQuantizationAwareTrainingConfig
+    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer
```
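To make the "Custom QAT APIs" section concrete, here is a short sketch adapted from this commit's description (not from the docs page itself): explicit fake quantization configs are built with `IntxFakeQuantizeConfig` and passed to `QATConfig` for cases where no matching PTQ base config exists. The toy model is a placeholder.

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# Placeholder model; any module containing nn.Linear layers works the same way
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# Explicit fake quantization configs, used when there is no corresponding PTQ base config
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)
```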

docs/source/finetuning.rst

Lines changed: 11 additions & 24 deletions

```diff
@@ -205,21 +205,14 @@ because we are not actually casting the fake quantized values.
 
 .. code:: py
 
-    from torchao.quantization import (
-        quantize_,
-    )
-    from torchao.quantization.qat import (
-        FakeQuantizeConfig,
-        IntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
+    from torchao.quantization.qat import QATConfig
+
     model = get_model()
 
-    # prepare: insert fake quantization ops
-    # swaps `torch.nn.Linear` with `FakeQuantizedLinear`
-    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
-    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
-    quantize_(model, qat_config)
+    # prepare: swap `torch.nn.Linear` -> `FakeQuantizedLinear`
+    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+    quantize_(model, QATConfig(base_config, step="prepare"))
 
     # fine-tune
     train_loop(model)
@@ -232,18 +225,12 @@ The next step is to actually quantize the model:
 
 .. code:: py
 
-    from torchao.quantization import (
-        Int8DynamicActivationInt4WeightConfig,
-    )
-    from torchao.quantization.qat import (
-        FromIntXQuantizationAwareTrainingConfig,
-    )
+    from torchao.quantization import Int8DynamicActivationInt4WeightConfig
 
-    # convert: transform fake quantization ops into actual quantized ops
-    # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
-    # quantized activation and weight tensor subclasses
-    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-    quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+    # convert: swap `FakeQuantizedLinear` -> `torch.nn.Linear`, then quantize using `base_config`
+    quantize_(model, QATConfig(base_config, step="convert"))
+
+    # inference or generate
 
 Now our model is ready for serving, and will typically have higher quantized
 accuracy than if we did not apply the prepare step (fake quantization) during
```
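One detail the finetuning docs do not show is module targeting: `quantize_` also accepts a `filter_fn`, and the tests in this commit use it to apply a weight-only QAT config to embeddings (activation fake quantization is not supported for embeddings). A minimal sketch, assuming a placeholder model with an embedding layer:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

# Placeholder model with an embedding followed by a linear layer
model = torch.nn.Sequential(torch.nn.Embedding(128, 64), torch.nn.Linear(64, 64))

# Embeddings only support weight-only fake quantization, so pass weight_config alone
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    QATConfig(weight_config=weight_config, step="prepare"),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```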

test/quantization/test_qat.py

Lines changed: 135 additions & 48 deletions

```diff
@@ -34,6 +34,7 @@
     ComposableQATQuantizer,
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
+    QATConfig,
     initialize_fake_quantizers,
 )
 from torchao.quantization.qat.embedding import (
@@ -59,7 +60,7 @@
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -1261,11 +1262,67 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_standalone(self):
+    def test_qat_config_init(self):
+        """
+        Test that the correct errors are thrown if `QATConfig` is not instantiated properly.
+        """
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
+
+        # OK
+        QATConfig(base_config, step="prepare")
+        QATConfig(base_config, step="convert")
+        QATConfig(activation_config=fq_config, weight_config=fq_config, step="prepare")
+        QATConfig(weight_config=fq_config, step="prepare")
+
+        # OK: good step values
+        self.assertEqual(QATConfig(base_config).step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="Prepare").step, "prepare")
+        self.assertEqual(QATConfig(base_config, step="CONVERT").step, "convert")
+
+        # Bad step
+        with self.assertRaisesRegex(
+            ValueError, "`step` must be either 'prepare' or 'convert'"
+        ):
+            QATConfig(base_config, step="blah")
+
+        # Step was not a keyword arg
+        with self.assertRaisesRegex(
+            TypeError, "4 positional arguments but 5 were given"
+        ):
+            QATConfig(base_config, None, None, "prepare")
+
+        # No configs are provided
+        with self.assertRaisesRegex(
+            ValueError, "One of `base_config` or `weight_config` must be specified"
+        ):
+            QATConfig(step="prepare")
+
+        # Clashing configs are provided
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, weight_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(ValueError, "Cannot specify both"):
+            QATConfig(base_config, activation_config=fq_config, step="prepare")
+        with self.assertRaisesRegex(
+            ValueError, "must be specified in the convert step"
+        ):
+            QATConfig(weight_config=fq_config, step="convert")
+
+        # FakeQuantizeConfigBase was specified as base_config
+        with self.assertRaisesRegex(
+            ValueError,
+            "was passed as `base_config`. Did you mean to do the following instead?",
+        ):
+            QATConfig(fq_config, step="prepare")
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_prepare(self):
         """
         Test that the following:
 
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, QATConfig(...))
 
         can produce the same results as `ComposableQATQuantizer`.
         """
@@ -1290,20 +1347,15 @@ def test_quantize_api_standalone(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)
 
         # quantize_ API
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config1 = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        qat_config2 = QATConfig(weight_config=weight_config)
+        quantize_(m, qat_config1)
         quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
-            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
+            m, qat_config2, filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding)
         )
 
         # Compare model values
@@ -1322,37 +1374,29 @@ def test_quantize_api_errors(self):
         Test that we throw exceptions with helpful error messages if `quantize_`
         runs into unexpected configurations.
         """
-        my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        fq_config = IntxFakeQuantizeConfig(torch.int8, group_size=32)
+        qat_config = QATConfig(activation_config=fq_config, weight_config=fq_config)
         m = M3()
 
         # Embedding currently only supports weight-only quantization
         with self.assertRaisesRegex(
             ValueError, "Activation fake quantization is not supported for embedding"
         ):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.Embedding),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.Embedding))
 
         # Only linear and embedding are supported currently
         with self.assertRaisesRegex(ValueError, "does not have QAT support"):
-            quantize_(
-                m,
-                IntXQuantizationAwareTrainingConfig(my_config, my_config),
-                lambda m, _: isinstance(m, torch.nn.ReLU),
-            )
+            quantize_(m, qat_config, lambda m, _: isinstance(m, torch.nn.ReLU))
 
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api_convert_path(self):
+    def test_quantize_api_e2e(self):
         """
         Test that the following:
 
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, int8_dynamic_activation_int4_weight())
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
 
         can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
         """
@@ -1370,16 +1414,8 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.prepare(baseline_model)
 
         # quantize_ prepare
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8,
-            "per_token",
-            is_symmetric=False,
-        )
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-        )
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
 
         # Compare prepared values
         torch.manual_seed(self.SEED)
@@ -1393,8 +1429,7 @@ def test_quantize_api_convert_path(self):
         baseline_model = baseline_quantizer.convert(baseline_model)
 
         # quantize_ convert
-        quantize_(m, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, QATConfig(base_config, step="convert"))
 
         # Compare converted values
         torch.manual_seed(self.SEED)
@@ -1447,14 +1482,12 @@ def test_qat_linear_bias(self):
         Test that QAT supports linear bias.
         """
         m = ModelWithLinearBias()
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False
-        )
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32)
-        quantize_(
-            m,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+        qat_config = QATConfig(
+            activation_config=act_config, weight_config=weight_config
         )
+        quantize_(m, qat_config)
         example_inputs = m.example_inputs()
         m(*example_inputs)
 
@@ -1653,7 +1686,7 @@ def test_qat_range_learning(self):
         )
         m = M()
         example_inputs = m.example_inputs()
-        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+        quantize_(m, QATConfig(weight_config=config))
 
         # Not initialized, should fail
         for t in m._get_all_weight_qparams():
@@ -1756,6 +1789,60 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_legacy_quantize_api_e2e(self):
+        """
+        Test that the following two APIs are numerically equivalent:
+
+        New API:
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
+            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
+
+        Old API:
+            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
+            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+            quantize_(model, Int8DynamicActivationInt4WeightConfig())
+        """
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline prepare
+        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
+        quantize_(baseline_model, old_qat_config)
+
+        # QATConfig prepare
+        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        quantize_(m, QATConfig(base_config, step="prepare"))
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+        # Baseline convert
+        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(baseline_model, base_config)
+
+        # quantize_ convert
+        quantize_(m, QATConfig(base_config, step="convert"))
+
+        # Compare converted values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
```
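Outside the test suite, the same `QATConfig` validation can be observed directly. The sketch below uses only constructors that appear in `test_qat_config_init` above and triggers the "clashing configs" error; the exact message text is defined by the implementation, so only its general shape is assumed here.

```python
import torch
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
fq_config = IntxFakeQuantizeConfig(torch.int8, "per_channel")

# Passing both a PTQ base config and an explicit fake quantize config is rejected
try:
    QATConfig(base_config, weight_config=fq_config, step="prepare")
except ValueError as e:
    print(e)  # expected to say the two cannot be specified together
```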
