Skip to content

Commit 8af7f44

Browse files
committed
Deprecate old QAT APIs
**Summary:** Deprecates QAT APIs that should no longer be used. Print helpful deprecation warning to help users migrate. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_api_deprecation ``` Also manual testing: ``` 'IntXQuantizationAwareTrainingConfig' is deprecated and will be removed in a future release. Please use the following API instead: base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) quantize_(model, QATConfig(base_config, step="prepare")) # train (not shown) quantize_(model, QATConfig(base_config, step="convert")) Alternatively, if you prefer to pass in fake quantization configs: activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) qat_config = QATConfig( activation_config=activation_config, weight_config=weight_config, step="prepare", ) quantize_(model, qat_config) Please see #2630 for more details. IntXQuantizationAwareTrainingConfig(activation_config=None, weight_config=None) ``` ghstack-source-id: 7ac9f3b Pull Request resolved: #2641
1 parent 7f5a6e4 commit 8af7f44

File tree

5 files changed

+105
-11
lines changed

5 files changed

+105
-11
lines changed

docs/source/api_ref_qat.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,13 @@ Custom QAT APIs
3232
linear.enable_linear_fake_quant
3333
linear.disable_linear_fake_quant
3434

35-
Legacy QAT APIs
35+
Legacy QAT Quantizers
3636
---------------------
3737

3838
.. autosummary::
3939
:toctree: generated/
4040
:nosignatures:
4141

42-
IntXQuantizationAwareTrainingConfig
43-
FromIntXQuantizationAwareTrainingConfig
4442
Int4WeightOnlyQATQuantizer
4543
linear.Int4WeightOnlyQATLinear
4644
Int8DynActInt4WeightQATQuantizer

test/quantization/test_qat.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import copy
1111
import unittest
12+
import warnings
1213
from typing import List
1314

1415
import torch
@@ -1844,6 +1845,45 @@ def test_legacy_quantize_api_e2e(self):
18441845
baseline_out = baseline_model(*x2)
18451846
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
18461847

1848+
@unittest.skipIf(
1849+
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
1850+
)
1851+
def test_qat_api_deprecation(self):
1852+
"""
1853+
Test that the appropriate deprecation warning is logged exactly once per class.
1854+
"""
1855+
from torchao.quantization.qat import (
1856+
FakeQuantizeConfig,
1857+
from_intx_quantization_aware_training,
1858+
intx_quantization_aware_training,
1859+
)
1860+
1861+
# Reset deprecation warning state, otherwise we won't log warnings here
1862+
warnings.resetwarnings()
1863+
1864+
# Map from deprecated API to the args needed to instantiate it
1865+
deprecated_apis_to_args = {
1866+
IntXQuantizationAwareTrainingConfig: (),
1867+
FromIntXQuantizationAwareTrainingConfig: (),
1868+
intx_quantization_aware_training: (),
1869+
from_intx_quantization_aware_training: (),
1870+
FakeQuantizeConfig: (torch.int8, "per_channel"),
1871+
}
1872+
1873+
with warnings.catch_warnings(record=True) as _warnings:
1874+
# Call each deprecated API twice
1875+
for cls, args in deprecated_apis_to_args.items():
1876+
cls(*args)
1877+
cls(*args)
1878+
1879+
# Each call should trigger the warning only once
1880+
self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
1881+
for w in _warnings:
1882+
self.assertIn(
1883+
"is deprecated and will be removed in a future release",
1884+
str(w.message),
1885+
)
1886+
18471887

18481888
if __name__ == "__main__":
18491889
unittest.main()

torchao/quantization/qat/api.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_infer_fake_quantize_configs,
2525
)
2626
from .linear import FakeQuantizedLinear
27+
from .utils import _log_deprecation_warning
2728

2829

2930
class QATStep(str, Enum):
@@ -224,11 +225,11 @@ def _qat_config_transform(
224225
return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
225226

226227

227-
# TODO: deprecate
228228
@dataclass
229229
class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
230230
"""
231-
(Will be deprecated soon)
231+
(Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
232+
232233
Config for applying fake quantization to a `torch.nn.Module`.
233234
to be used with :func:`~torchao.quantization.quant_api.quantize_`.
234235
@@ -256,9 +257,13 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
256257
activation_config: Optional[FakeQuantizeConfigBase] = None
257258
weight_config: Optional[FakeQuantizeConfigBase] = None
258259

260+
def __post_init__(self):
261+
_log_deprecation_warning(self)
262+
259263

260264
# for BC
261-
intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
265+
class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig):
266+
pass
262267

263268

264269
@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
@@ -286,10 +291,11 @@ def _intx_quantization_aware_training_transform(
286291
raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
287292

288293

289-
# TODO: deprecate
294+
@dataclass
290295
class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
291296
"""
292-
(Will be deprecated soon)
297+
(Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
298+
293299
Config for converting a model with fake quantized modules,
294300
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
295301
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
@@ -306,11 +312,13 @@ class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
306312
)
307313
"""
308314

309-
pass
315+
def __post_init__(self):
316+
_log_deprecation_warning(self)
310317

311318

312319
# for BC
313-
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
320+
class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig):
321+
pass
314322

315323

316324
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
ZeroPointDomain,
2626
)
2727

28+
from .utils import _log_deprecation_warning
29+
2830

2931
class FakeQuantizeConfigBase(abc.ABC):
3032
"""
@@ -134,6 +136,14 @@ def __init__(
134136
if is_dynamic and range_learning:
135137
raise ValueError("`is_dynamic` is not compatible with `range_learning`")
136138

139+
self.__post_init__()
140+
141+
def __post_init__(self):
142+
"""
143+
For deprecation only, can remove after https://github.com/pytorch/ao/issues/2630.
144+
"""
145+
pass
146+
137147
def _get_granularity(
138148
self,
139149
granularity: Union[Granularity, str, None],
@@ -260,7 +270,13 @@ def __setattr__(self, name: str, value: Any):
260270

261271

262272
# For BC
263-
FakeQuantizeConfig = IntxFakeQuantizeConfig
273+
class FakeQuantizeConfig(IntxFakeQuantizeConfig):
274+
"""
275+
(Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizeConfig` instead.
276+
"""
277+
278+
def __post_init__(self):
279+
_log_deprecation_warning(self)
264280

265281

266282
def _infer_fake_quantize_configs(

torchao/quantization/qat/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import warnings
8+
from typing import Any
79

810
import torch
911

@@ -104,3 +106,33 @@ def _get_qmin_qmax(n_bit: int, symmetric: bool = True):
104106
qmin = 0
105107
qmax = 2**n_bit - 1
106108
return (qmin, qmax)
109+
110+
111+
def _log_deprecation_warning(old_api_object: Any):
112+
"""
113+
Log a helpful deprecation message pointing users to the new QAT API,
114+
only once per deprecated class.
115+
"""
116+
warnings.warn(
117+
"""'%s' is deprecated and will be removed in a future release. Please use the following API instead:
118+
119+
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
120+
quantize_(model, QATConfig(base_config, step="prepare"))
121+
# train (not shown)
122+
quantize_(model, QATConfig(base_config, step="convert"))
123+
124+
Alternatively, if you prefer to pass in fake quantization configs:
125+
126+
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
127+
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
128+
qat_config = QATConfig(
129+
activation_config=activation_config,
130+
weight_config=weight_config,
131+
step="prepare",
132+
)
133+
quantize_(model, qat_config)
134+
135+
Please see https://github.com/pytorch/ao/issues/2630 for more details.
136+
"""
137+
% old_api_object.__class__.__name__
138+
)

0 commit comments

Comments
 (0)