Skip to content

Commit b5aeafa

Browse files
committed
Deprecate old QAT APIs
**Summary:** Deprecates QAT APIs that should no longer be used. Print helpful deprecation warning to help users migrate. **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_api_deprecation ``` Also manual testing: ``` 'IntXQuantizationAwareTrainingConfig' is deprecated and will be removed in a future release. Please use the following API instead: base_config = Int8DynamicActivationInt4WeightConfig(group_size=32) quantize_(model, QATConfig(base_config, step="prepare")) # train (not shown) quantize_(model, QATConfig(base_config, step="convert")) Alternatively, if you prefer to pass in fake quantization configs: activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) qat_config = QATConfig( activation_config=activation_config, weight_config=weight_config, step="prepare", ) quantize_(model, qat_config) Please see #2630 for more details. IntXQuantizationAwareTrainingConfig(activation_config=None, weight_config=None) ``` ghstack-source-id: 7ac9f3b Pull Request resolved: #2641
1 parent 7f5a6e4 commit b5aeafa

File tree

5 files changed

+134
-11
lines changed

5 files changed

+134
-11
lines changed

docs/source/api_ref_qat.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,13 @@ Custom QAT APIs
3232
linear.enable_linear_fake_quant
3333
linear.disable_linear_fake_quant
3434

35-
Legacy QAT APIs
35+
Legacy QAT Quantizers
3636
---------------------
3737

3838
.. autosummary::
3939
:toctree: generated/
4040
:nosignatures:
4141

42-
IntXQuantizationAwareTrainingConfig
43-
FromIntXQuantizationAwareTrainingConfig
4442
Int4WeightOnlyQATQuantizer
4543
linear.Int4WeightOnlyQATLinear
4644
Int8DynActInt4WeightQATQuantizer

test/quantization/test_qat.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
# This test takes a long time to run
99

1010
import copy
11+
import io
12+
import logging
1113
import unittest
1214
from typing import List
1315

@@ -1844,6 +1846,64 @@ def test_legacy_quantize_api_e2e(self):
18441846
baseline_out = baseline_model(*x2)
18451847
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
18461848

1849+
def _test_deprecation(self, deprecated_class, *example_args, first_time=True):
1850+
"""
1851+
Assert that instantiating a deprecated class triggers the deprecation warning.
1852+
"""
1853+
try:
1854+
log_stream = io.StringIO()
1855+
handler = logging.StreamHandler(log_stream)
1856+
logger = logging.getLogger(deprecated_class.__module__)
1857+
logger.addHandler(handler)
1858+
logger.setLevel(logging.WARN)
1859+
deprecated_class(*example_args)
1860+
if first_time:
1861+
regex = (
1862+
"'%s' is deprecated and will be removed in a future release"
1863+
% deprecated_class.__name__
1864+
)
1865+
self.assertIn(regex, log_stream.getvalue())
1866+
else:
1867+
self.assertEqual(log_stream.getvalue(), "")
1868+
finally:
1869+
logger.removeHandler(handler)
1870+
handler.close()
1871+
1872+
@unittest.skipIf(
1873+
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
1874+
)
1875+
def test_qat_api_deprecation(self):
1876+
"""
1877+
Test that the appropriate deprecation warning has been logged.
1878+
"""
1879+
from torchao.quantization.qat import (
1880+
FakeQuantizeConfig,
1881+
from_intx_quantization_aware_training,
1882+
intx_quantization_aware_training,
1883+
)
1884+
from torchao.quantization.qat.utils import _LOGGED_DEPRECATED_CLASSES
1885+
1886+
# Reset deprecation warning state, otherwise we won't log warnings here
1887+
_LOGGED_DEPRECATED_CLASSES.clear()
1888+
1889+
# Assert that the deprecation warning is logged
1890+
self._test_deprecation(IntXQuantizationAwareTrainingConfig)
1891+
self._test_deprecation(FromIntXQuantizationAwareTrainingConfig)
1892+
self._test_deprecation(intx_quantization_aware_training)
1893+
self._test_deprecation(from_intx_quantization_aware_training)
1894+
self._test_deprecation(FakeQuantizeConfig, torch.int8, "per_channel")
1895+
1896+
# Assert that warning is only logged once per class
1897+
self._test_deprecation(IntXQuantizationAwareTrainingConfig, first_time=False)
1898+
self._test_deprecation(
1899+
FromIntXQuantizationAwareTrainingConfig, first_time=False
1900+
)
1901+
self._test_deprecation(intx_quantization_aware_training, first_time=False)
1902+
self._test_deprecation(from_intx_quantization_aware_training, first_time=False)
1903+
self._test_deprecation(
1904+
FakeQuantizeConfig, torch.int8, "per_channel", first_time=False
1905+
)
1906+
18471907

18481908
if __name__ == "__main__":
18491909
unittest.main()

torchao/quantization/qat/api.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_infer_fake_quantize_configs,
2525
)
2626
from .linear import FakeQuantizedLinear
27+
from .utils import _log_deprecation_warning
2728

2829

2930
class QATStep(str, Enum):
@@ -224,11 +225,11 @@ def _qat_config_transform(
224225
return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)
225226

226227

227-
# TODO: deprecate
228228
@dataclass
229229
class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
230230
"""
231-
(Will be deprecated soon)
231+
(Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
232+
232233
Config for applying fake quantization to a `torch.nn.Module`.
233234
to be used with :func:`~torchao.quantization.quant_api.quantize_`.
234235
@@ -256,9 +257,13 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
256257
activation_config: Optional[FakeQuantizeConfigBase] = None
257258
weight_config: Optional[FakeQuantizeConfigBase] = None
258259

260+
def __post_init__(self):
261+
_log_deprecation_warning(self)
262+
259263

260264
# for BC
261-
intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
265+
class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig):
266+
pass
262267

263268

264269
@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
@@ -286,10 +291,11 @@ def _intx_quantization_aware_training_transform(
286291
raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
287292

288293

289-
# TODO: deprecate
294+
@dataclass
290295
class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
291296
"""
292-
(Will be deprecated soon)
297+
(Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
298+
293299
Config for converting a model with fake quantized modules,
294300
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
295301
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
@@ -306,11 +312,13 @@ class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
306312
)
307313
"""
308314

309-
pass
315+
def __post_init__(self):
316+
_log_deprecation_warning(self)
310317

311318

312319
# for BC
313-
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
320+
class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig):
321+
pass
314322

315323

316324
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
ZeroPointDomain,
2626
)
2727

28+
from .utils import _log_deprecation_warning
29+
2830

2931
class FakeQuantizeConfigBase(abc.ABC):
3032
"""
@@ -134,6 +136,14 @@ def __init__(
134136
if is_dynamic and range_learning:
135137
raise ValueError("`is_dynamic` is not compatible with `range_learning`")
136138

139+
self.__post_init__()
140+
141+
def __post_init__(self):
142+
"""
143+
For deprecation only, can remove after https://github.com/pytorch/ao/issues/2630.
144+
"""
145+
pass
146+
137147
def _get_granularity(
138148
self,
139149
granularity: Union[Granularity, str, None],
@@ -260,7 +270,13 @@ def __setattr__(self, name: str, value: Any):
260270

261271

262272
# For BC
263-
FakeQuantizeConfig = IntxFakeQuantizeConfig
273+
class FakeQuantizeConfig(IntxFakeQuantizeConfig):
274+
"""
275+
(Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizeConfig` instead.
276+
"""
277+
278+
def __post_init__(self):
279+
_log_deprecation_warning(self)
264280

265281

266282
def _infer_fake_quantize_configs(

torchao/quantization/qat/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
8+
from typing import Any
79

810
import torch
911

@@ -104,3 +106,42 @@ def _get_qmin_qmax(n_bit: int, symmetric: bool = True):
104106
qmin = 0
105107
qmax = 2**n_bit - 1
106108
return (qmin, qmax)
109+
110+
111+
# log deprecation warning only once per class
112+
_LOGGED_DEPRECATED_CLASSES = set[type]()
113+
114+
115+
def _log_deprecation_warning(old_api_object: Any):
116+
"""
117+
Log a helpful deprecation message pointing users to the new QAT API,
118+
only once per deprecated class.
119+
"""
120+
global _LOGGED_DEPRECATED_CLASSES
121+
if old_api_object.__class__ in _LOGGED_DEPRECATED_CLASSES:
122+
return
123+
_LOGGED_DEPRECATED_CLASSES.add(old_api_object.__class__)
124+
logger = logging.getLogger(old_api_object.__module__)
125+
logger.warning(
126+
"""'%s' is deprecated and will be removed in a future release. Please use the following API instead:
127+
128+
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
129+
quantize_(model, QATConfig(base_config, step="prepare"))
130+
# train (not shown)
131+
quantize_(model, QATConfig(base_config, step="convert"))
132+
133+
Alternatively, if you prefer to pass in fake quantization configs:
134+
135+
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
136+
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
137+
qat_config = QATConfig(
138+
activation_config=activation_config,
139+
weight_config=weight_config,
140+
step="prepare",
141+
)
142+
quantize_(model, qat_config)
143+
144+
Please see https://github.com/pytorch/ao/issues/2630 for more details.
145+
"""
146+
% old_api_object.__class__.__name__
147+
)

0 commit comments

Comments
 (0)