From 33a8305c70b884577c13e9b60d40bfdef451d4ff Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Thu, 7 Aug 2025 14:13:42 -0700
Subject: [PATCH] Generalize FakeQuantizer beyond intx

**Summary:** Similar to https://github.com/pytorch/ao/pull/2628, but for `FakeQuantizer`. It is cleaner to isolate the logic of each quantizer in separate classes, e.g. intx vs nvfp4 vs fp8.

Naming change:
```
FakeQuantizer -> IntxFakeQuantizer
```

**BC-breaking notes:** This is technically not BC-breaking yet since we are just deprecating the old APIs while keeping them around. It will become BC-breaking when we remove the old APIs in the future, according to https://github.com/pytorch/ao/issues/2630.

Before:
```
config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
FakeQuantizer(config)
```

After:
```
config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
IntxFakeQuantizer(config)
# or
FakeQuantizerBase.from_config(config)
```

**Test Plan:**
```
python test/quantization/test_qat.py
```

[ghstack-poisoned]
---
 docs/source/api_ref_qat.rst                |  3 +-
 test/quantization/test_qat.py              | 20 +++++----
 .../prototype/qat/fake_quantizer.py        |  2 +-
 torchao/quantization/qat/__init__.py       | 10 ++++-
 torchao/quantization/qat/api.py            |  6 +--
 torchao/quantization/qat/embedding.py      |  4 +-
 torchao/quantization/qat/fake_quantizer.py | 45 ++++++++++++++-----
 torchao/quantization/qat/linear.py         |  8 ++--
 8 files changed, 67 insertions(+), 31 deletions(-)

diff --git a/docs/source/api_ref_qat.rst b/docs/source/api_ref_qat.rst
index bfac8f398d..0179af2f3d 100644
--- a/docs/source/api_ref_qat.rst
+++ b/docs/source/api_ref_qat.rst
@@ -28,7 +28,8 @@ Custom QAT APIs
     IntxFakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
-    FakeQuantizer
+    FakeQuantizerBase
+    IntxFakeQuantizer
     linear.enable_linear_fake_quant
     linear.disable_linear_fake_quant
 
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 48a9f780b6..bb4bfe7f10 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -46,7 +46,7 @@
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
-    FakeQuantizer,
+    IntxFakeQuantizer,
     _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
@@ -1466,10 +1466,10 @@ def test_fake_quantize_config_torch_intx(self):
     )
     def test_fake_quantizer_repr(self):
         """
-        Test that `repr(FakeQuantizer(config))` exposes useful config details.
+        Test that `repr(IntxFakeQuantizer(config))` exposes useful config details.
         """
         config = IntxFakeQuantizeConfig(torch.int4, group_size=128)
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         fake_quantizer_repr = repr(fake_quantizer)
         self.assertTrue("dtype=torch.int4" in fake_quantizer_repr)
         self.assertTrue("group_size=128" in fake_quantizer_repr)
@@ -1500,7 +1500,7 @@ def test_qat_linear_bias(self):
     def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
-        1. FakeQuantizer with asymmetric per_token config
+        1. IntxFakeQuantizer with asymmetric per_token config
         2. torchao.quantization.utils.per_token_dynamic_quant
         """
         from torchao.quantization.utils import per_token_dynamic_quant
@@ -1508,7 +1508,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         torch.manual_seed(self.SEED)
         x = torch.randn(1, 235, 2048).to(dtype)
         config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         fake_quantizer_out = fake_quantizer(x)
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
@@ -1580,7 +1580,7 @@ def test_fake_quantize_config_eps(self):
             is_symmetric=False,
             eps=eps,
         )
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         actual_out = fake_quantizer(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
@@ -1638,7 +1638,7 @@ def test_qat_8da4w_eps(self):
     )
     def test_fake_quantizer_range_learning(self):
         """
-        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
         """
         config = IntxFakeQuantizeConfig(
             torch.int8,
@@ -1648,7 +1648,7 @@ def test_fake_quantizer_range_learning(self):
             scale_precision=torch.float32,
             zero_point_precision=torch.float32,
         )
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         example_inputs = (torch.randn(2, 3),)
 
         # Not initialized, should fail
@@ -1770,7 +1770,7 @@ def test_qat_fp8a4w_quantizer(self):
             self.assertIsInstance(
                 linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
             )
-            self.assertIsInstance(linear.weight_fake_quantizer, FakeQuantizer)
+            self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
         prev_weight = copy.deepcopy(m.linear1.weight)
 
         # Simulate training
@@ -1854,6 +1854,7 @@ def test_qat_api_deprecation(self):
         """
         from torchao.quantization.qat import (
             FakeQuantizeConfig,
+            FakeQuantizer,
             from_intx_quantization_aware_training,
             intx_quantization_aware_training,
         )
@@ -1868,6 +1869,7 @@ def test_qat_api_deprecation(self):
             intx_quantization_aware_training: (),
             from_intx_quantization_aware_training: (),
             FakeQuantizeConfig: (torch.int8, "per_channel"),
+            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),),
         }
 
         with warnings.catch_warnings(record=True) as _warnings:
diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/quantization/prototype/qat/fake_quantizer.py
index 3bbe1fb704..560a609ce2 100644
--- a/torchao/quantization/prototype/qat/fake_quantizer.py
+++ b/torchao/quantization/prototype/qat/fake_quantizer.py
@@ -1,5 +1,5 @@
 from torchao.quantization.qat.fake_quantizer import (
-    FakeQuantizer,
+    IntxFakeQuantizer as FakeQuantizer,
 )
 
 __all__ = [
diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py
index 5d3d0996d0..9a7338623d 100644
--- a/torchao/quantization/qat/__init__.py
+++ b/torchao/quantization/qat/__init__.py
@@ -17,7 +17,11 @@
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
 )
-from .fake_quantizer import FakeQuantizer
+from .fake_quantizer import (
+    FakeQuantizer,
+    FakeQuantizerBase,
+    IntxFakeQuantizer,
+)
 from .linear import (
     FakeQuantizedLinear,
     Float8ActInt4WeightQATQuantizer,
@@ -29,8 +33,9 @@
     "QATConfig",
     "QATStep",
     "FakeQuantizeConfigBase",
+    "FakeQuantizerBase",
     "IntxFakeQuantizeConfig",
-    "FakeQuantizer",
+    "IntxFakeQuantizer",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
     # Prototype
@@ -42,6 +47,7 @@
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     # for BC
+    "FakeQuantizer",
     "FakeQuantizeConfig",
     "from_intx_quantization_aware_training",
     "FromIntXQuantizationAwareTrainingConfig",
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 0d69f44bd9..8273aff343 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -382,14 +382,14 @@ def initialize_fake_quantizers(
 ) -> None:
     """
     (Prototype) Initialize the scales and zero points on all
-    :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    :class:`~torchao.quantization.qat.fake_quantizer.IntxFakeQuantizer`
     in the model based on the provided example inputs.
     """
     # avoid circular dependencies
-    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+    from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer
 
     def _set_initialized(m: torch.nn.Module):
-        if isinstance(m, FakeQuantizer):
+        if isinstance(m, IntxFakeQuantizer):
             m._initialized = True
 
     model.apply(_set_initialized)
diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py
index 778ba2b83c..28a3f2cee0 100644
--- a/torchao/quantization/qat/embedding.py
+++ b/torchao/quantization/qat/embedding.py
@@ -17,7 +17,7 @@
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
 )
-from .fake_quantizer import FakeQuantizer
+from .fake_quantizer import FakeQuantizerBase
 from .utils import (
     _get_qmin_qmax,
 )
@@ -66,7 +66,7 @@ def __init__(
             **kwargs,
         )
         if weight_config is not None:
-            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+            self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
         else:
             self.weight_fake_quantizer = None
 
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 3cb873f3ff..8c31418ee9 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -34,15 +34,37 @@
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _Float8RowwiseFakeQuantize,
+    _log_deprecation_warning,
 )
 
 
-class FakeQuantizer(torch.nn.Module):
+class FakeQuantizerBase(torch.nn.Module):
     """
     Generic module for applying fake quantization to a tensor, as specified in the config.
     """
 
-    def __init__(self, config: FakeQuantizeConfigBase):
+    config: FakeQuantizeConfigBase
+
+    def __repr__(self) -> str:
+        """
+        Return a human readable representation of this `FakeQuantizer` with config details.
+        """
+        return "FakeQuantizer(%s)" % self.config
+
+    @staticmethod
+    def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
+        if isinstance(config, IntxFakeQuantizeConfig):
+            return IntxFakeQuantizer(config)
+        else:
+            raise ValueError(f"Unknown config type: {config}")
+
+
+class IntxFakeQuantizer(FakeQuantizerBase):
+    """
+    Generic module for applying integer fake quantization to a tensor, as specified in the config.
+    """
+
+    def __init__(self, config: IntxFakeQuantizeConfig):
         super().__init__()
         self.config = config
         self.enabled = True
@@ -62,9 +84,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if not self.enabled:
             return x
 
-        if not isinstance(self.config, IntxFakeQuantizeConfig):
-            raise ValueError("Only IntxFakeQuantizeConfig is supported currently")
-
         if (
             self.config.range_learning
             and not self._initialized
@@ -186,13 +205,19 @@ def _maybe_update_qparams_for_range_learning(self) -> None:
         self.scale = torch.nn.Parameter(scale, requires_grad=True)
         self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
 
-    def __repr__(self) -> str:
-        """
-        Return a human readable representation of this `FakeQuantizer` with config details.
-        """
-        return "FakeQuantizer(%s)" % self.config
+
+# For BC
+class FakeQuantizer(IntxFakeQuantizer):
+    """
+    (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizer` instead.
+    """
+
+    def __init__(self, config: FakeQuantizeConfigBase):
+        super().__init__(config)
+        _log_deprecation_warning(self)
 
 
+# TODO: make this a FakeQuantizerBase
 class _Float8RowwiseActivationFakeQuantizer(torch.nn.Module):
     """
     Simple fake quantizer for float8 rowwise fake quantization, intended for activations only.
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index c9c8f8ea5d..59e759dab3 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -32,7 +32,7 @@
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
-    FakeQuantizer,
+    FakeQuantizerBase,
     _Float8RowwiseActivationFakeQuantizer,
 )
 from .utils import (
@@ -84,7 +84,9 @@ def __init__(
         )
         # initialize activation fake quantizer
         if activation_config is not None:
-            self.activation_fake_quantizer = FakeQuantizer(activation_config)
+            self.activation_fake_quantizer = FakeQuantizerBase.from_config(
+                activation_config
+            )
         else:
             self.activation_fake_quantizer = None
 
@@ -97,7 +99,7 @@ def __init__(
                 "in_features (%s) %% group_size (%s) must be == 0"
                 % (in_features, group_size)
             )
-            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+            self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
         else:
             self.weight_fake_quantizer = None