Skip to content

Deprecate old QAT APIs #2641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c358b1b
[bc-breaking] Generalize FakeQuantizeConfig beyond intx
andrewor14 Jul 29, 2025
d076264
New multi-step QAT API
andrewor14 Jul 29, 2025
8f56651
Update on "New multi-step QAT API"
andrewor14 Jul 29, 2025
7a9fe90
Update on "New multi-step QAT API"
andrewor14 Jul 30, 2025
1e88ebf
Deprecate old QAT APIs
andrewor14 Jul 30, 2025
12e8c3f
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
2ed5e50
Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
019b665
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
b0c4721
Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
0eb0983
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
c91b218
Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
affc74e
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
3069075
Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
b41f4e7
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
12d920b
Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
68728d7
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
52f72a5
Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
56415d6
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
08f87af
Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
be45ff4
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
28b3b41
Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
9baae23
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Aug 1, 2025
62cd942
Update on "Deprecate old QAT APIs"
andrewor14 Aug 1, 2025
1c30bbb
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Aug 1, 2025
2fbfbb6
Update on "Deprecate old QAT APIs"
andrewor14 Aug 1, 2025
3f06429
Merge branch 'main' into gh/andrewor14/15/head
andrewor14 Aug 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import copy
import unittest
import warnings
from typing import List

import torch
Expand Down Expand Up @@ -1844,6 +1845,45 @@ def test_legacy_quantize_api_e2e(self):
baseline_out = baseline_model(*x2)
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

@unittest.skipIf(
    not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4"
)
def test_qat_api_deprecation(self):
    """
    Test that the appropriate deprecation warning is logged exactly once per class.
    """
    from torchao.quantization.qat import (
        FakeQuantizeConfig,
        from_intx_quantization_aware_training,
        intx_quantization_aware_training,
    )

    # Reset deprecation warning state, otherwise we won't log warnings here
    # (earlier tests may have already triggered the once-per-location filter)
    warnings.resetwarnings()

    # Map from deprecated API to the args needed to instantiate it
    deprecated_apis_to_args = {
        IntXQuantizationAwareTrainingConfig: (),
        FromIntXQuantizationAwareTrainingConfig: (),
        intx_quantization_aware_training: (),
        from_intx_quantization_aware_training: (),
        FakeQuantizeConfig: (torch.int8, "per_channel"),
    }

    with warnings.catch_warnings(record=True) as _warnings:
        # Call each deprecated API twice; the default warning filter
        # deduplicates repeated warnings from the same location
        for cls, args in deprecated_apis_to_args.items():
            cls(*args)
            cls(*args)

        # Each call should trigger the warning only once
        self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
        for w in _warnings:
            self.assertIn(
                "is deprecated and will be removed in a future release",
                str(w.message),
            )


if __name__ == "__main__":
unittest.main()
22 changes: 15 additions & 7 deletions torchao/quantization/qat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_infer_fake_quantize_configs,
)
from .linear import FakeQuantizedLinear
from .utils import _log_deprecation_warning


class QATStep(str, Enum):
Expand Down Expand Up @@ -224,11 +225,11 @@ def _qat_config_transform(
return _QUANTIZE_CONFIG_HANDLER[type(base_config)](module, base_config)


# TODO: deprecate
@dataclass
class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
"""
(Will be deprecated soon)
(Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.

Config for applying fake quantization to a `torch.nn.Module`.
to be used with :func:`~torchao.quantization.quant_api.quantize_`.

Expand Down Expand Up @@ -256,9 +257,13 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
activation_config: Optional[FakeQuantizeConfigBase] = None
weight_config: Optional[FakeQuantizeConfigBase] = None

def __post_init__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I heard typing_extensions.deprecated is more IDE friendly: pytorch/pytorch#153892 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I don't think torchao has a dependency on typing_extensions? Do we want to add a new one for this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, OK probably not worth to add deps since it looks like right now there is no deps for torchao

_log_deprecation_warning(self)


# for BC
intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig):
    """
    (Deprecated) snake_case alias of :class:`IntXQuantizationAwareTrainingConfig`,
    kept for backward compatibility. NOTE(review): defined as a subclass rather
    than a plain name binding — presumably so the deprecation warning emitted
    from the parent's ``__post_init__`` reports this alias's own class name;
    confirm against ``_log_deprecation_warning``.
    """

    pass


@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
Expand Down Expand Up @@ -286,10 +291,11 @@ def _intx_quantization_aware_training_transform(
raise ValueError("Module of type '%s' does not have QAT support" % type(mod))


# TODO: deprecate
@dataclass
class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
"""
(Will be deprecated soon)
(Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.

Config for converting a model with fake quantized modules,
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
Expand All @@ -306,11 +312,13 @@ class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
)
"""

pass
def __post_init__(self):
    # Emit the deprecation warning (logged once per deprecated class) every
    # time this legacy config is instantiated.
    _log_deprecation_warning(self)


# for BC
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig):
    """
    (Deprecated) snake_case alias of :class:`FromIntXQuantizationAwareTrainingConfig`,
    kept for backward compatibility. NOTE(review): defined as a subclass rather
    than a plain name binding — presumably so the deprecation warning emitted
    from the parent's ``__post_init__`` reports this alias's own class name;
    confirm against ``_log_deprecation_warning``.
    """

    pass


@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
Expand Down
18 changes: 17 additions & 1 deletion torchao/quantization/qat/fake_quantize_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
ZeroPointDomain,
)

from .utils import _log_deprecation_warning


class FakeQuantizeConfigBase(abc.ABC):
"""
Expand Down Expand Up @@ -134,6 +136,14 @@ def __init__(
if is_dynamic and range_learning:
raise ValueError("`is_dynamic` is not compatible with `range_learning`")

self.__post_init__()

def __post_init__(self):
    """
    For deprecation only, can remove after https://github.com/pytorch/ao/issues/2630.
    """
    # Intentional no-op hook invoked at the end of __init__; deprecated
    # subclasses (e.g. the legacy `FakeQuantizeConfig` alias) override it
    # to emit a deprecation warning.
    pass

def _get_granularity(
self,
granularity: Union[Granularity, str, None],
Expand Down Expand Up @@ -260,7 +270,13 @@ def __setattr__(self, name: str, value: Any):


# For BC
FakeQuantizeConfig = IntxFakeQuantizeConfig
class FakeQuantizeConfig(IntxFakeQuantizeConfig):
    """
    (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizeConfig` instead.
    """

    # Overrides the parent's no-op `__post_init__` hook (called at the end of
    # `IntxFakeQuantizeConfig.__init__`) so that constructing this legacy
    # alias logs the deprecation warning once per class.
    def __post_init__(self):
        _log_deprecation_warning(self)


def _infer_fake_quantize_configs(
Expand Down
32 changes: 32 additions & 0 deletions torchao/quantization/qat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Any

import torch

Expand Down Expand Up @@ -104,3 +106,33 @@ def _get_qmin_qmax(n_bit: int, symmetric: bool = True):
qmin = 0
qmax = 2**n_bit - 1
return (qmin, qmax)


def _log_deprecation_warning(old_api_object: Any):
"""
Log a helpful deprecation message pointing users to the new QAT API,
only once per deprecated class.
"""
warnings.warn(
"""'%s' is deprecated and will be removed in a future release. Please use the following API instead:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a concrete time for deprecation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it's in the issue, deprecated this version (0.13.0) and removed the next (0.14.0)


base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))
# train (not shown)
quantize_(model, QATConfig(base_config, step="convert"))

Alternatively, if you prefer to pass in fake quantization configs:

activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
qat_config = QATConfig(
activation_config=activation_config,
weight_config=weight_config,
step="prepare",
)
quantize_(model, qat_config)

Please see https://github.com/pytorch/ao/issues/2630 for more details.
"""
% old_api_object.__class__.__name__
)
Loading