From c358b1b064da8bb6817f3b837d7f629116114687 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 29 Jul 2025 13:02:43 -0700 Subject: [PATCH 1/4] [bc-breaking] Generalize FakeQuantizeConfig beyond intx **Summary:** The existing `FakeQuantizeConfig` performs only intx quantization, but we plan to extend QAT to other dtypes such as fp8 and nvfp4 in the near future. This is the necessary refactor before that. Specifically: ``` # New abstract class FakeQuantizeConfigBase # Rename FakeQuantizeConfig -> IntxFakeQuantizeConfig ``` In the future, we will have other types of `FakeQuantizeConfigBase` for float dtypes that users can pass in instead of the existing Intx one. **BC-breaking notes:** For BC, we keep around the old names to reference the new ones. However, this commit is still BC-breaking in the sense that a few APIs now accept the abstract `FakeQuantizeConfigBase` instead. For the most part, this abstract class will be hidden from the user. Before: ``` activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=32) ``` After: ``` activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) ``` **Test Plan:** python test/quantization/test_qat.py [ghstack-poisoned] --- README.md | 6 +- docs/source/api_ref_qat.rst | 2 +- test/prototype/test_parq.py | 4 +- test/quantization/test_qat.py | 173 ++++++------ .../tests/test_embedding_xbit_quantizer.py | 4 +- ...est_int8_dynamic_activation_intx_weight.py | 6 +- torchao/quantization/qat/README.md | 12 +- torchao/quantization/qat/__init__.py | 13 +- torchao/quantization/qat/api.py | 252 +---------------- torchao/quantization/qat/embedding.py | 13 +- .../quantization/qat/fake_quantize_config.py | 262 ++++++++++++++++++ torchao/quantization/qat/fake_quantizer.py | 10 +- torchao/quantization/qat/linear.py | 53 ++-- 13 files changed, 433 insertions(+), 377 deletions(-) create mode 100644 torchao/quantization/qat/fake_quantize_config.py diff --git a/README.md b/README.md index 72fd2d7403..1fb51c9dfe 100644 --- a/README.md +++ b/README.md @@ -180,9 +180,9 @@ Post-training quantization can result in a fast and compact model, but may also ```python from torchao.quantization import quantize_ -from torchao.quantization.qat import FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig -activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) +from torchao.quantization.qat import IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig +activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config), quantize_(my_model, qat_config) ``` diff --git a/docs/source/api_ref_qat.rst b/docs/source/api_ref_qat.rst index 046a1b74a4..b912e6ffef 100644 --- a/docs/source/api_ref_qat.rst +++ b/docs/source/api_ref_qat.rst @@ -24,7 +24,7 @@ Custom QAT APIs :toctree: generated/ :nosignatures: - FakeQuantizeConfig + IntxFakeQuantizeConfig FakeQuantizedLinear FakeQuantizedEmbedding FakeQuantizer diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 36765fb9b5..6ceeb0d795 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -30,8 +30,8 @@ from torchao.prototype.parq.quant.uniform_torchao import 
_BIT_WIDTH_TO_DTYPE from torchao.quantization.granularity import PerGroup from torchao.quantization.qat import ( - FakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, + IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig, ) from torchao.quantization.quant_api import ( @@ -393,7 +393,7 @@ def test_int8_dynamic_activation_intx_e2e( optimizer.step() # apply torchao quantized activations on top - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, granularity="per_token", mapping_type=config.act_mapping_type, diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index ee3ac50cbf..c83f64022b 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -32,15 +32,16 @@ ) from torchao.quantization.qat.api import ( ComposableQATQuantizer, - FakeQuantizeConfig, + FromIntXQuantizationAwareTrainingConfig, IntXQuantizationAwareTrainingConfig, - from_intx_quantization_aware_training, initialize_fake_quantizers, - intx_quantization_aware_training, ) from torchao.quantization.qat.embedding import ( FakeQuantizedEmbedding, ) +from torchao.quantization.qat.fake_quantize_config import ( + IntxFakeQuantizeConfig, +) from torchao.quantization.qat.fake_quantizer import ( FakeQuantizer, _Float8RowwiseActivationFakeQuantizer, @@ -829,26 +830,28 @@ def test_qat_4w_embedding(self): def test_fake_quantize_config_granularity(self): """ - Test initialization and property setting of `FakeQuantizeConfig`'s granularity. + Test initialization and property setting of `IntxFakeQuantizeConfig`'s granularity. """ # per token - per_token_config1 = FakeQuantizeConfig(torch.int8, PerToken()) - per_token_config2 = FakeQuantizeConfig(torch.int8, "per_token") + per_token_config1 = IntxFakeQuantizeConfig(torch.int8, PerToken()) + per_token_config2 = IntxFakeQuantizeConfig(torch.int8, "per_token") self.assertIsInstance(per_token_config1.granularity, PerToken) self.assertIsInstance(per_token_config2.granularity, PerToken) # per channel - per_channel_config1 = FakeQuantizeConfig(torch.int8, PerAxis(0)) - per_channel_config2 = FakeQuantizeConfig(torch.int8, "per_channel") + per_channel_config1 = IntxFakeQuantizeConfig(torch.int8, PerAxis(0)) + per_channel_config2 = IntxFakeQuantizeConfig(torch.int8, "per_channel") self.assertIsInstance(per_channel_config1.granularity, PerAxis) self.assertIsInstance(per_channel_config2.granularity, PerAxis) self.assertEqual(per_channel_config1.granularity.axis, 0) self.assertEqual(per_channel_config2.granularity.axis, 0) # per group - per_group_config1 = FakeQuantizeConfig(torch.int8, PerGroup(32)) - per_group_config2 = FakeQuantizeConfig(torch.int8, "per_group", group_size=32) - per_group_config3 = FakeQuantizeConfig(torch.int8, group_size=32) + per_group_config1 = IntxFakeQuantizeConfig(torch.int8, PerGroup(32)) + per_group_config2 = IntxFakeQuantizeConfig( + torch.int8, "per_group", group_size=32 + ) + per_group_config3 = IntxFakeQuantizeConfig(torch.int8, group_size=32) self.assertIsInstance(per_group_config1.granularity, PerGroup) self.assertIsInstance(per_group_config2.granularity, PerGroup) self.assertIsInstance(per_group_config3.granularity, PerGroup) @@ -869,48 +872,48 @@ def test_fake_quantize_config_granularity(self): def test_fake_quantize_config_granularity_error_cases(self): """ - Test incorrect settings of `FakeQuantizeConfig`'s granularity. + Test incorrect settings of `IntxFakeQuantizeConfig`'s granularity. 
""" # no granularity provided with self.assertRaisesRegex( ValueError, "`granularity` or `group_size` must be set" ): - FakeQuantizeConfig(torch.int8) + IntxFakeQuantizeConfig(torch.int8) # group_size with conflicting granularity msg = "`group_size` conflicts with granularity" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, PerToken(), group_size=32) + IntxFakeQuantizeConfig(torch.int8, PerToken(), group_size=32) with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, PerGroup(64), group_size=32) + IntxFakeQuantizeConfig(torch.int8, PerGroup(64), group_size=32) with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, "per_token", group_size=32) + IntxFakeQuantizeConfig(torch.int8, "per_token", group_size=32) # 'per_group' but no group_size msg = "Granularity was 'per_group' but no `group_size` was set" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int8, "per_group") + IntxFakeQuantizeConfig(torch.int8, "per_group") # not supported with self.assertRaisesRegex(ValueError, "not supported"): - FakeQuantizeConfig(torch.int8, PerRow()) + IntxFakeQuantizeConfig(torch.int8, PerRow()) with self.assertRaisesRegex(ValueError, "Only axis=0 is supported"): - FakeQuantizeConfig(torch.int8, PerAxis(1)) + IntxFakeQuantizeConfig(torch.int8, PerAxis(1)) with self.assertRaisesRegex(ValueError, "Unexpected granularity"): - FakeQuantizeConfig(torch.int8, "blah") + IntxFakeQuantizeConfig(torch.int8, "blah") with self.assertRaisesRegex(ValueError, "unexpected type"): - FakeQuantizeConfig(torch.int8, 1234) + IntxFakeQuantizeConfig(torch.int8, 1234) def test_fake_quantize_config_mapping_type(self): """ - Test initialization and property setting of `FakeQuantizeConfig`'s mapping type. + Test initialization and property setting of `IntxFakeQuantizeConfig`'s mapping type. """ # symmetric - symmetric_config1 = FakeQuantizeConfig(torch.int8, "per_token") - symmetric_config2 = FakeQuantizeConfig( + symmetric_config1 = IntxFakeQuantizeConfig(torch.int8, "per_token") + symmetric_config2 = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=True ) - symmetric_config3 = FakeQuantizeConfig( + symmetric_config3 = IntxFakeQuantizeConfig( torch.int8, "per_token", MappingType.SYMMETRIC ) self.assertEqual(symmetric_config1.mapping_type, MappingType.SYMMETRIC) @@ -921,10 +924,10 @@ def test_fake_quantize_config_mapping_type(self): self.assertTrue(symmetric_config3.is_symmetric) # asymmetric - asymmetric_config1 = FakeQuantizeConfig( + asymmetric_config1 = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False ) - asymmetric_config2 = FakeQuantizeConfig( + asymmetric_config2 = IntxFakeQuantizeConfig( torch.int8, "per_token", MappingType.ASYMMETRIC ) self.assertEqual(asymmetric_config1.mapping_type, MappingType.ASYMMETRIC) @@ -940,60 +943,60 @@ def test_fake_quantize_config_mapping_type(self): # bad config1: both mapping_type and is_symmetric are set msg = "Cannot set both `mapping_type` and `is_symmetric`" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_token", MappingType.SYMMETRIC, is_symmetric=False ) # bad config2: not supported with self.assertRaisesRegex(ValueError, "not supported"): - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_token", MappingType.SYMMETRIC_NO_CLIPPING_ERR ) def test_fake_quantize_config_dtype(self): """ - Test that unsupported dtypes are caught in `FakeQuantizeConfig`. 
+ Test that unsupported dtypes are caught in `IntxFakeQuantizeConfig`. """ msg = "Unsupported dtype" with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int16, "per_token") + IntxFakeQuantizeConfig(torch.int16, "per_token") with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.int32, "per_token") + IntxFakeQuantizeConfig(torch.int32, "per_token") with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.bfloat16, "per_token") + IntxFakeQuantizeConfig(torch.bfloat16, "per_token") with self.assertRaisesRegex(ValueError, msg): - FakeQuantizeConfig(torch.float32, "per_token") + IntxFakeQuantizeConfig(torch.float32, "per_token") # OK if TORCH_VERSION_AT_LEAST_2_3: - FakeQuantizeConfig(torch.uint1, "per_token") - FakeQuantizeConfig(torch.uint2, "per_token") - FakeQuantizeConfig(torch.uint3, "per_token") - FakeQuantizeConfig(torch.uint4, "per_token") - FakeQuantizeConfig(torch.uint5, "per_token") - FakeQuantizeConfig(torch.uint6, "per_token") - FakeQuantizeConfig(torch.uint7, "per_token") - FakeQuantizeConfig(torch.uint8, "per_token") - FakeQuantizeConfig(TorchAODType.INT1, "per_token") - FakeQuantizeConfig(TorchAODType.INT2, "per_token") - FakeQuantizeConfig(TorchAODType.INT3, "per_token") - FakeQuantizeConfig(TorchAODType.INT4, "per_token") - FakeQuantizeConfig(TorchAODType.INT5, "per_token") - FakeQuantizeConfig(TorchAODType.INT6, "per_token") - FakeQuantizeConfig(TorchAODType.INT7, "per_token") - FakeQuantizeConfig(torch.int8, "per_token") + IntxFakeQuantizeConfig(torch.uint1, "per_token") + IntxFakeQuantizeConfig(torch.uint2, "per_token") + IntxFakeQuantizeConfig(torch.uint3, "per_token") + IntxFakeQuantizeConfig(torch.uint4, "per_token") + IntxFakeQuantizeConfig(torch.uint5, "per_token") + IntxFakeQuantizeConfig(torch.uint6, "per_token") + IntxFakeQuantizeConfig(torch.uint7, "per_token") + IntxFakeQuantizeConfig(torch.uint8, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT1, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT2, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT3, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT4, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT5, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT6, "per_token") + IntxFakeQuantizeConfig(TorchAODType.INT7, "per_token") + IntxFakeQuantizeConfig(torch.int8, "per_token") def test_fake_quantize_config_dynamic_and_range_learning(self): """ Test that `is_dynamic` and `range_learning` cannot both be set. """ - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=True, range_learning=False ) - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, range_learning=True ) with self.assertRaisesRegex(ValueError, "not compatible"): - FakeQuantizeConfig( + IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=True, range_learning=True ) @@ -1010,10 +1013,12 @@ def test_fake_quantized_linear_8da4w(self): 256, 688, bias=False, - activation_config=FakeQuantizeConfig( + activation_config=IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False ), - weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), + weight_config=IntxFakeQuantizeConfig( + TorchAODType.INT4, group_size=group_size + ), ) def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: @@ -1059,7 +1064,7 @@ def test_fake_quantized_linear_4w(self): Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`. 
""" group_size = 128 - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=torch.uint4, group_size=group_size, is_symmetric=False, @@ -1172,7 +1177,9 @@ def test_fake_quantized_embedding_4w(self): fq_embedding = FakeQuantizedEmbedding( num_embeddings, embedding_dim, - weight_config=FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size), + weight_config=IntxFakeQuantizeConfig( + TorchAODType.INT4, group_size=group_size + ), ) def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: @@ -1258,7 +1265,7 @@ def test_quantize_api_standalone(self): """ Test that the following: - quantize_(model, intx_quantization_aware_training(...)) + quantize_(model, IntXQuantizationAwareTrainingConfig(...)) can produce the same results as `ComposableQATQuantizer`. """ @@ -1283,19 +1290,19 @@ def test_quantize_api_standalone(self): baseline_model = baseline_quantizer.prepare(baseline_model) # quantize_ API - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) quantize_( m, - intx_quantization_aware_training(activation_config, weight_config), + IntXQuantizationAwareTrainingConfig(activation_config, weight_config), ) quantize_( m, - intx_quantization_aware_training(weight_config=weight_config), + IntXQuantizationAwareTrainingConfig(weight_config=weight_config), filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), ) @@ -1315,7 +1322,7 @@ def test_quantize_api_errors(self): Test that we throw exceptions with helpful error messages if `quantize_` runs into unexpected configurations. """ - my_config = FakeQuantizeConfig(torch.int8, group_size=32) + my_config = IntxFakeQuantizeConfig(torch.int8, group_size=32) m = M3() # Embedding currently only supports weight-only quantization @@ -1324,7 +1331,7 @@ def test_quantize_api_errors(self): ): quantize_( m, - intx_quantization_aware_training(my_config, my_config), + IntXQuantizationAwareTrainingConfig(my_config, my_config), lambda m, _: isinstance(m, torch.nn.Embedding), ) @@ -1332,7 +1339,7 @@ def test_quantize_api_errors(self): with self.assertRaisesRegex(ValueError, "does not have QAT support"): quantize_( m, - intx_quantization_aware_training(my_config, my_config), + IntXQuantizationAwareTrainingConfig(my_config, my_config), lambda m, _: isinstance(m, torch.nn.ReLU), ) @@ -1343,8 +1350,8 @@ def test_quantize_api_convert_path(self): """ Test that the following: - quantize_(model, intx_quantization_aware_training(...)) - quantize_(model, from_intx_quantization_aware_training(...)) + quantize_(model, IntXQuantizationAwareTrainingConfig(...)) + quantize_(model, FromIntXQuantizationAwareTrainingConfig(...)) quantize_(model, int8_dynamic_activation_int4_weight()) can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert. 
@@ -1363,15 +1370,15 @@ def test_quantize_api_convert_path(self): baseline_model = baseline_quantizer.prepare(baseline_model) # quantize_ prepare - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) quantize_( m, - intx_quantization_aware_training(activation_config, weight_config), + IntXQuantizationAwareTrainingConfig(activation_config, weight_config), ) # Compare prepared values @@ -1386,7 +1393,7 @@ def test_quantize_api_convert_path(self): baseline_model = baseline_quantizer.convert(baseline_model) # quantize_ convert - quantize_(m, from_intx_quantization_aware_training()) + quantize_(m, FromIntXQuantizationAwareTrainingConfig()) quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size)) # Compare converted values @@ -1402,11 +1409,11 @@ def test_quantize_api_convert_path(self): ) def test_fake_quantize_config_torch_intx(self): """ - Test that `FakeQuantizeConfig` works with torch.intx. + Test that `IntxFakeQuantizeConfig` works with torch.intx. """ group_size = 16 - config1 = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) - config2 = FakeQuantizeConfig(torch.int4, group_size=group_size) + config1 = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size) + config2 = IntxFakeQuantizeConfig(torch.int4, group_size=group_size) linear1 = FakeQuantizedLinear(32, 64, weight_config=config1) linear2 = FakeQuantizedLinear(32, 64, weight_config=config2) linear2.weight = linear1.weight @@ -1424,7 +1431,7 @@ def test_fake_quantizer_repr(self): """ Test that `repr(FakeQuantizer(config))` exposes useful config details. """ - config = FakeQuantizeConfig(torch.int4, group_size=128) + config = IntxFakeQuantizeConfig(torch.int4, group_size=128) fake_quantizer = FakeQuantizer(config) fake_quantizer_repr = repr(fake_quantizer) self.assertTrue("dtype=torch.int4" in fake_quantizer_repr) @@ -1440,13 +1447,13 @@ def test_qat_linear_bias(self): Test that QAT supports linear bias. """ m = ModelWithLinearBias() - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False ) - weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32) + weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=32) quantize_( m, - intx_quantization_aware_training(activation_config, weight_config), + IntXQuantizationAwareTrainingConfig(activation_config, weight_config), ) example_inputs = m.example_inputs() m(*example_inputs) @@ -1465,7 +1472,7 @@ def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype): torch.manual_seed(self.SEED) x = torch.randn(1, 235, 2048).to(dtype) - config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) fake_quantizer = FakeQuantizer(config) fake_quantizer_out = fake_quantizer(x) baseline_out = per_token_dynamic_quant(x) @@ -1518,7 +1525,7 @@ def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype): ) def test_fake_quantize_config_eps(self): """ - Test that users can set arbitrary eps value in `FakeQuantizeConfig`. + Test that users can set arbitrary eps value in `IntxFakeQuantizeConfig`. 
""" eps = 0.00123 x = torch.randn(2, 3).to(torch.float32) @@ -1532,7 +1539,7 @@ def test_fake_quantize_config_eps(self): eps=eps, ) expected_out = _fake_quantize_per_token(x, scale, zp, -128, 127) - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, @@ -1598,7 +1605,7 @@ def test_fake_quantizer_range_learning(self): """ Test that range learning requires `FakeQuantizer`s to be initialized correctly. """ - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, @@ -1636,7 +1643,7 @@ def test_qat_range_learning(self): """ Test end-to-end QAT flow with range learning. """ - config = FakeQuantizeConfig( + config = IntxFakeQuantizeConfig( torch.int8, "per_channel", is_dynamic=False, diff --git a/torchao/experimental/tests/test_embedding_xbit_quantizer.py b/torchao/experimental/tests/test_embedding_xbit_quantizer.py index 442612410e..1a87245ad4 100644 --- a/torchao/experimental/tests/test_embedding_xbit_quantizer.py +++ b/torchao/experimental/tests/test_embedding_xbit_quantizer.py @@ -21,9 +21,9 @@ ) from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.qat import ( - FakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, Int4WeightOnlyEmbeddingQATQuantizer, + IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig, ) from torchao.quantization.quant_api import ( @@ -282,7 +282,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( ) embedding_filter = lambda m, fqn: isinstance(m, torch.nn.Embedding) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( weight_dtype, group_size=group_size, is_symmetric=is_symmetric, diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py index 5cba538068..d8fedd7745 100644 --- a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -16,9 +16,9 @@ from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.qat import ( - FakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, Int8DynActInt4WeightQATQuantizer, + IntxFakeQuantizeConfig, IntXQuantizationAwareTrainingConfig, ) from torchao.quantization.quant_api import ( @@ -538,12 +538,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig( model = model.to(model_dtype) activations = activations.to(model_dtype) - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=is_act_symmetric, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( weight_dtype, group_size=group_size, is_symmetric=is_symmetric, diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md index 6395952ab5..777181b67e 100644 --- a/torchao/quantization/qat/README.md +++ b/torchao/quantization/qat/README.md @@ -71,7 +71,7 @@ def train_loop(m: torch.nn.Module): The recommended way to run QAT in torchao is through the `quantize_` API: 1. 
**Prepare:** specify how weights and/or activations are to be quantized through -[`FakeQuantizeConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.FakeQuantizeConfig.html#torchao.quantization.qat.FakeQuantizeConfig) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntXQuantizationAwareTrainingConfig.html#torchao.quantization.qat.IntXQuantizationAwareTrainingConfig) +[`IntxFakeQuantizeConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntxFakeQuantizeConfig.html#torchao.quantization.qat.IntxFakeQuantizeConfig) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.qat.IntXQuantizationAwareTrainingConfig.html#torchao.quantization.qat.IntXQuantizationAwareTrainingConfig) 2. **Convert:** quantize the model using the standard post-training quantization (PTQ) functions such as [`Int8DynamicActivationInt4WeightConfig`](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int8DynamicActivationInt4WeightConfig.html#torchao.quantization.Int8DynamicActivationInt4WeightConfig) @@ -84,7 +84,7 @@ from torchao.quantization import ( Int8DynamicActivationInt4WeightConfig, ) from torchao.quantization.qat import ( - FakeQuantizeConfig, + IntxFakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, IntXQuantizationAwareTrainingConfig, ) @@ -92,8 +92,8 @@ model = get_model() # prepare: insert fake quantization ops # swaps `torch.nn.Linear` with `FakeQuantizedLinear` -activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) +activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) quantize_( model, IntXQuantizationAwareTrainingConfig(activation_config, weight_config), @@ -116,8 +116,8 @@ the following with a filter function during the prepare step: ``` # first apply linear transformation to the model as above -activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) -weight_config = FakeQuantizeConfig(torch.int4, group_size=32) +activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) +weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) quantize_( model, IntXQuantizationAwareTrainingConfig(activation_config, weight_config), diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 72cecfd254..1035cd8a38 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -1,6 +1,5 @@ from .api import ( ComposableQATQuantizer, - FakeQuantizeConfig, FromIntXQuantizationAwareTrainingConfig, IntXQuantizationAwareTrainingConfig, from_intx_quantization_aware_training, @@ -11,6 +10,11 @@ FakeQuantizedEmbedding, Int4WeightOnlyEmbeddingQATQuantizer, ) +from .fake_quantize_config import ( + FakeQuantizeConfig, + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, +) from .fake_quantizer import FakeQuantizer from .linear import ( FakeQuantizedLinear, @@ -21,7 +25,7 @@ __all__ = [ "ComposableQATQuantizer", - "FakeQuantizeConfig", + "FakeQuantizeConfigBase", "FakeQuantizedLinear", "FakeQuantizedEmbedding", "FakeQuantizer", @@ -30,8 +34,11 @@ "Int4WeightOnlyEmbeddingQATQuantizer", "Int4WeightOnlyQATQuantizer", "Int8DynActInt4WeightQATQuantizer", + "IntxFakeQuantizeConfig", 
"IntXQuantizationAwareTrainingConfig", "initialize_fake_quantizers", - "intx_quantization_aware_training", + # for BC + "FakeQuantizeConfig", "from_intx_quantization_aware_training", + "intx_quantization_aware_training", ] diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index b7df56409f..22607269c8 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,252 +5,20 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple import torch from torchao.core.config import AOBaseConfig -from torchao.quantization.granularity import ( - Granularity, - PerAxis, - PerGroup, - PerToken, -) -from torchao.quantization.quant_primitives import ( - _SUB_BYTE_INT_BOUNDS, - _SUB_BYTE_UINT_BOUNDS, - MappingType, - TorchAODType, - ZeroPointDomain, -) from torchao.quantization.transform_module import ( register_quantize_module_handler, ) from torchao.quantization.unified import TwoStepQuantizer - -@dataclass -class FakeQuantizeConfig: - """ - Config for how to fake quantize weights or activations. - - Args: - dtype: dtype to simulate during fake quantization, e.g. torch.int8. - For PyTorch versions older than 2.6, you may use `TorchAODType` to represent - torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. - granularity: granularity of scales and zero points, e.g. PerGroup(32). - We also support the following strings: - 1) 'per_token': equivalent to PerToken() - 2) 'per_channel': equivalent to PerAxis(0) - 3) 'per_group': equivalent to PerGroup(group_size), must be combined - with separate `group_size` kwarg, Alternatively, just set the - `group_size` kwarg and leave this field empty. - mapping_type: whether to use symmetric (default) or asymmetric quantization - Alternatively, set `is_symmetric` (bool) and leave this field empty. - scale_precision: scale dtype (default torch.fp32) - zero_point_precision: zero point dtype (default torch.int32) - zero_point_domain: whether zero point is in integer (default) or float domain - is_dynamic: whether to use dynamic (default) or static scale and zero points - range_learning (prototype): whether to learn scale and zero points during training - (default false), not compatible with `is_dynamic`. 
- - Keyword args: - group_size: size of each group in per group fake quantization, - can be set instead of `granularity` - is_symmetric: whether to use symmetric or asymmetric quantization, - can be set instead of `mapping_type` - - Example usage:: - - # Per token asymmetric quantization - FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) - FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC) - - # Per channel symmetric quantization - FakeQuantizeConfig(torch.int4, "per_channel") - FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True) - FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC) - - # Per group symmetric quantization - FakeQuantizeConfig(torch.int4, group_size=32) - FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True) - FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True) - FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC) - """ - - dtype: Union[torch.dtype, TorchAODType] - granularity: Granularity - mapping_type: MappingType - scale_precision: torch.dtype - zero_point_precision: torch.dtype - zero_point_domain: ZeroPointDomain - is_dynamic: bool = True - range_learning: bool = False - eps: Optional[float] = None - - def __init__( - self, - dtype: Union[torch.dtype, TorchAODType], - granularity: Union[Granularity, str, None] = None, - mapping_type: Optional[MappingType] = None, - scale_precision: torch.dtype = torch.float32, - zero_point_precision: torch.dtype = torch.int32, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - is_dynamic: bool = True, - range_learning: bool = False, - eps: Optional[float] = None, - *, - group_size: Optional[int] = None, - is_symmetric: Optional[bool] = None, - ): - if zero_point_domain is None: - raise ValueError("Please use ZeroPointDomain.NONE instead of None") - self.dtype = dtype - self.granularity = self._get_granularity(granularity, group_size) - self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) - self.scale_precision = scale_precision - self.zero_point_precision = zero_point_precision - self.zero_point_domain = zero_point_domain - self.is_dynamic = is_dynamic - self.range_learning = range_learning - self.eps = eps - - # Validate dtype - all_dtypes = [torch.int8, torch.uint8] - all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys())) - all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys())) - if dtype not in all_dtypes: - raise ValueError( - "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes) - ) - - # Dynamic is not compatible with range learning - if is_dynamic and range_learning: - raise ValueError("`is_dynamic` is not compatible with `range_learning`") - - def _get_granularity( - self, - granularity: Union[Granularity, str, None], - group_size: Optional[int], - ) -> Granularity: - """ - Parse the `Granularity` represented in the args. 
- - Granularity can be specified in one of three ways: - 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) - 2) str: one of 'per_token', 'per_channel', and 'per_group' - 3) None: `group_size` must be set instead, represents per group granularity - """ - # If group_size is set, then granularity must be either "per_group" or None - if ( - group_size is not None - and granularity != "per_group" - and granularity is not None - ): - raise ValueError( - "`group_size` conflicts with granularity '%s'" % granularity - ) - - # Case 1: Granularity object - if isinstance(granularity, Granularity): - if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): - raise ValueError("Granularity '%s' is not supported" % granularity) - if isinstance(granularity, PerAxis) and granularity.axis != 0: - raise ValueError("Only axis=0 is supported for PerAxis granularity") - return granularity - - # Case 2: str granularity - if granularity == "per_token": - return PerToken() - elif granularity == "per_channel": - return PerAxis(axis=0) - elif granularity == "per_group": - if group_size is None: - raise ValueError( - "Granularity was 'per_group' but no `group_size` was set" - ) - return PerGroup(group_size) - elif isinstance(granularity, str): - raise ValueError( - "Unexpected granularity: '%s', must be one of %s" - % (granularity, ["per_token", "per_channel", "per_group"]) - ) - - # Case 3: None granularity + group_size was specified - if granularity is not None: - raise ValueError( - "Granularity '%s' has unexpected type %s" - % (granularity, type(granularity)) - ) - if group_size is None: - raise ValueError( - "At least one of `granularity` or `group_size` must be set" - ) - return PerGroup(group_size) - - def _get_mapping_type( - self, - mapping_type: Optional[MappingType], - is_symmetric: Optional[bool], - ) -> MappingType: - """ - Parse the `MappingType` represented in the args. - - Mapping type can be specified in one of two ways: - 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC - 2): is_symmetric bool - """ - if mapping_type is not None and is_symmetric is not None: - raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") - - # Case 0: Default to symmetric - if mapping_type is None and is_symmetric is None: - return MappingType.SYMMETRIC - - # Case 1: MappingType object - if mapping_type is not None: - if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: - raise ValueError("MappingType '%s' is not supported" % mapping_type) - return mapping_type - - # Case 2: is_symmetric flag - assert is_symmetric is not None - if is_symmetric: - return MappingType.SYMMETRIC - else: - return MappingType.ASYMMETRIC - - @property - def group_size(self) -> int: - """ - If this is per group granularity, return the group size. - Otherwise, throw an error. - """ - if isinstance(self.granularity, PerGroup): - return self.granularity.group_size - else: - raise ValueError( - "`group_size` is undefined for %s granularity" % self.granularity - ) - - @property - def is_symmetric(self) -> bool: - """ - Return True if mapping type is symmetric, else False (asymmetric). - """ - return self.mapping_type == MappingType.SYMMETRIC - - def __setattr__(self, name: str, value: Any): - """ - Support setting `group_size` and `is_symmetric`. 
- """ - if name == "group_size": - super().__setattr__("granularity", PerGroup(value)) - elif name == "is_symmetric": - mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC - super().__setattr__("mapping_type", mapping_type) - else: - super().__setattr__(name, value) +from .fake_quantize_config import ( + FakeQuantizeConfig, # noqa: F401, for BC + FakeQuantizeConfigBase, +) @dataclass @@ -262,11 +30,11 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig): Example usage:: from torchao.quantization import quantize_ - from torchao.quantization.qat import FakeQuantizeConfig - activation_config = FakeQuantizeConfig( + from torchao.quantization.qat import IntxFakeQuantizeConfig + activation_config = IntxFakeQuantizeConfig( torch.int8, "per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( torch.int4, group_size=32, is_symmetric=True, ) quantize_( @@ -280,8 +48,8 @@ class IntXQuantizationAwareTrainingConfig(AOBaseConfig): ValueError as these are not supported. """ - activation_config: Optional[FakeQuantizeConfig] = None - weight_config: Optional[FakeQuantizeConfig] = None + activation_config: Optional[FakeQuantizeConfigBase] = None + weight_config: Optional[FakeQuantizeConfigBase] = None # for BC diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index aec23712ed..778ba2b83c 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -13,7 +13,10 @@ from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.utils import get_group_qparams_symmetric -from .api import FakeQuantizeConfig +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, +) from .fake_quantizer import FakeQuantizer from .utils import ( _get_qmin_qmax, @@ -29,7 +32,7 @@ class FakeQuantizedEmbedding(torch.nn.Embedding): Example usage:: - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, group_size=8, symmetric=True, @@ -47,7 +50,7 @@ def __init__( norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, - weight_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, *args, **kwargs, ) -> None: @@ -105,7 +108,7 @@ def to_embedding(self) -> torch.nn.Embedding: def from_embedding( cls, mod: torch.nn.Embedding, - weight_config: Optional[FakeQuantizeConfig] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, ): new_embedding = FakeQuantizedEmbedding( mod.num_embeddings, @@ -285,7 +288,7 @@ def __init__( *args, **kwargs, ): - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=TorchAODType.INT4, group_size=group_size, is_symmetric=True, diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py new file mode 100644 index 0000000000..c6ad0b39c3 --- /dev/null +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import abc +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch + +from torchao.quantization.granularity import ( + Granularity, + PerAxis, + PerGroup, + PerToken, +) +from torchao.quantization.quant_primitives import ( + _SUB_BYTE_INT_BOUNDS, + _SUB_BYTE_UINT_BOUNDS, + MappingType, + TorchAODType, + ZeroPointDomain, +) + + +@dataclass +class FakeQuantizeConfigBase(abc.ABC): + """ + Base class for representing fake quantization config. + """ + + pass + + +@dataclass +class IntxFakeQuantizeConfig(FakeQuantizeConfigBase): + """ + Config for how to fake quantize weights or activations. + + Args: + dtype: dtype to simulate during fake quantization, e.g. torch.int8. + For PyTorch versions older than 2.6, you may use `TorchAODType` to represent + torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4. + granularity: granularity of scales and zero points, e.g. PerGroup(32). + We also support the following strings: + 1) 'per_token': equivalent to PerToken() + 2) 'per_channel': equivalent to PerAxis(0) + 3) 'per_group': equivalent to PerGroup(group_size), must be combined + with separate `group_size` kwarg, Alternatively, just set the + `group_size` kwarg and leave this field empty. + mapping_type: whether to use symmetric (default) or asymmetric quantization + Alternatively, set `is_symmetric` (bool) and leave this field empty. + scale_precision: scale dtype (default torch.fp32) + zero_point_precision: zero point dtype (default torch.int32) + zero_point_domain: whether zero point is in integer (default) or float domain + is_dynamic: whether to use dynamic (default) or static scale and zero points + range_learning (prototype): whether to learn scale and zero points during training + (default false), not compatible with `is_dynamic`. 
+ + Keyword args: + group_size: size of each group in per group fake quantization, + can be set instead of `granularity` + is_symmetric: whether to use symmetric or asymmetric quantization, + can be set instead of `mapping_type` + + Example usage:: + + # Per token asymmetric quantization + IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) + IntxFakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC) + + # Per channel symmetric quantization + IntxFakeQuantizeConfig(torch.int4, "per_channel") + IntxFakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True) + IntxFakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC) + + # Per group symmetric quantization + IntxFakeQuantizeConfig(torch.int4, group_size=32) + IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True) + IntxFakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True) + IntxFakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC) + """ + + dtype: Union[torch.dtype, TorchAODType] + granularity: Granularity + mapping_type: MappingType + scale_precision: torch.dtype + zero_point_precision: torch.dtype + zero_point_domain: ZeroPointDomain + is_dynamic: bool = True + range_learning: bool = False + eps: Optional[float] = None + + def __init__( + self, + dtype: Union[torch.dtype, TorchAODType], + granularity: Union[Granularity, str, None] = None, + mapping_type: Optional[MappingType] = None, + scale_precision: torch.dtype = torch.float32, + zero_point_precision: torch.dtype = torch.int32, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + is_dynamic: bool = True, + range_learning: bool = False, + eps: Optional[float] = None, + *, + group_size: Optional[int] = None, + is_symmetric: Optional[bool] = None, + ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + self.dtype = dtype + self.granularity = self._get_granularity(granularity, group_size) + self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) + self.scale_precision = scale_precision + self.zero_point_precision = zero_point_precision + self.zero_point_domain = zero_point_domain + self.is_dynamic = is_dynamic + self.range_learning = range_learning + self.eps = eps + + # Validate dtype + all_dtypes = [torch.int8, torch.uint8] + all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys())) + all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys())) + if dtype not in all_dtypes: + raise ValueError( + "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes) + ) + + # Dynamic is not compatible with range learning + if is_dynamic and range_learning: + raise ValueError("`is_dynamic` is not compatible with `range_learning`") + + def _get_granularity( + self, + granularity: Union[Granularity, str, None], + group_size: Optional[int], + ) -> Granularity: + """ + Parse the `Granularity` represented in the args. 
+ + Granularity can be specified in one of three ways: + 1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size) + 2) str: one of 'per_token', 'per_channel', and 'per_group' + 3) None: `group_size` must be set instead, represents per group granularity + """ + # If group_size is set, then granularity must be either "per_group" or None + if ( + group_size is not None + and granularity != "per_group" + and granularity is not None + ): + raise ValueError( + "`group_size` conflicts with granularity '%s'" % granularity + ) + + # Case 1: Granularity object + if isinstance(granularity, Granularity): + if not isinstance(granularity, (PerToken, PerAxis, PerGroup)): + raise ValueError("Granularity '%s' is not supported" % granularity) + if isinstance(granularity, PerAxis) and granularity.axis != 0: + raise ValueError("Only axis=0 is supported for PerAxis granularity") + return granularity + + # Case 2: str granularity + if granularity == "per_token": + return PerToken() + elif granularity == "per_channel": + return PerAxis(axis=0) + elif granularity == "per_group": + if group_size is None: + raise ValueError( + "Granularity was 'per_group' but no `group_size` was set" + ) + return PerGroup(group_size) + elif isinstance(granularity, str): + raise ValueError( + "Unexpected granularity: '%s', must be one of %s" + % (granularity, ["per_token", "per_channel", "per_group"]) + ) + + # Case 3: None granularity + group_size was specified + if granularity is not None: + raise ValueError( + "Granularity '%s' has unexpected type %s" + % (granularity, type(granularity)) + ) + if group_size is None: + raise ValueError( + "At least one of `granularity` or `group_size` must be set" + ) + return PerGroup(group_size) + + def _get_mapping_type( + self, + mapping_type: Optional[MappingType], + is_symmetric: Optional[bool], + ) -> MappingType: + """ + Parse the `MappingType` represented in the args. + + Mapping type can be specified in one of two ways: + 1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC + 2): is_symmetric bool + """ + if mapping_type is not None and is_symmetric is not None: + raise ValueError("Cannot set both `mapping_type` and `is_symmetric`") + + # Case 0: Default to symmetric + if mapping_type is None and is_symmetric is None: + return MappingType.SYMMETRIC + + # Case 1: MappingType object + if mapping_type is not None: + if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]: + raise ValueError("MappingType '%s' is not supported" % mapping_type) + return mapping_type + + # Case 2: is_symmetric flag + assert is_symmetric is not None + if is_symmetric: + return MappingType.SYMMETRIC + else: + return MappingType.ASYMMETRIC + + @property + def group_size(self) -> int: + """ + If this is per group granularity, return the group size. + Otherwise, throw an error. + """ + if isinstance(self.granularity, PerGroup): + return self.granularity.group_size + else: + raise ValueError( + "`group_size` is undefined for %s granularity" % self.granularity + ) + + @property + def is_symmetric(self) -> bool: + """ + Return True if mapping type is symmetric, else False (asymmetric). + """ + return self.mapping_type == MappingType.SYMMETRIC + + def __setattr__(self, name: str, value: Any): + """ + Support setting `group_size` and `is_symmetric`. 
+ """ + if name == "group_size": + super().__setattr__("granularity", PerGroup(value)) + elif name == "is_symmetric": + mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC + super().__setattr__("mapping_type", mapping_type) + else: + super().__setattr__(name, value) + + +# For BC +FakeQuantizeConfig = IntxFakeQuantizeConfig diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index b7ad792dc1..3cb873f3ff 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -26,8 +26,9 @@ get_groupwise_affine_qparams, ) -from .api import ( - FakeQuantizeConfig, +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, ) from .utils import ( _fake_quantize_per_channel_group, @@ -41,7 +42,7 @@ class FakeQuantizer(torch.nn.Module): Generic module for applying fake quantization to a tensor, as specified in the config. """ - def __init__(self, config: FakeQuantizeConfig): + def __init__(self, config: FakeQuantizeConfigBase): super().__init__() self.config = config self.enabled = True @@ -61,6 +62,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.enabled: return x + if not isinstance(self.config, IntxFakeQuantizeConfig): + raise ValueError("Only IntxFakeQuantizeConfig is supported currently") + if ( self.config.range_learning and not self._initialized diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index 02b48fc5e3..c9c8f8ea5d 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -27,7 +27,10 @@ from torchao.quantization.utils import get_group_qparams_symmetric from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 -from .api import FakeQuantizeConfig +from .fake_quantize_config import ( + FakeQuantizeConfigBase, + IntxFakeQuantizeConfig, +) from .fake_quantizer import ( FakeQuantizer, _Float8RowwiseActivationFakeQuantizer, @@ -46,12 +49,12 @@ class FakeQuantizedLinear(torch.nn.Linear): Example usage:: - activation_config = FakeQuantizeConfig( + activation_config = IntxFakeQuantizeConfig( dtype=torch.int8, granularity="per_token", is_symmetric=False, ) - weight_config = FakeQuantizeConfig( + weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, group_size=8, is_symmetric=True, @@ -67,8 +70,8 @@ def __init__( in_features: int, out_features: int, bias: bool = False, - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, + activation_config: Optional[FakeQuantizeConfigBase] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, *args, **kwargs, ) -> None: @@ -127,8 +130,8 @@ def to_linear(self) -> torch.nn.Linear: def from_linear( cls, mod: torch.nn.Linear, - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, + activation_config: Optional[FakeQuantizeConfigBase] = None, + weight_config: Optional[FakeQuantizeConfigBase] = None, ): new_linear = FakeQuantizedLinear( mod.in_features, @@ -179,10 +182,10 @@ class _LegacyQATQuantizer(TwoStepQuantizer): Base class for sharing common methods across legacy QAT quantizers. 
""" - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return None - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return None @@ -281,10 +284,10 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module): else: self._convert_qat_linear_8da4w(child) - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_8da4w_activation_config(self.activation_scales_precision) - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_8da4w_weight_config(self.groupsize, self.scales_precision) @@ -354,13 +357,15 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module): mod.disable_fake_quant() -def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig: +def _get_8da4w_activation_config( + qparams_precision: torch.dtype, +) -> IntxFakeQuantizeConfig: """ - Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. + Return the activation `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. """ # TODO: generalize this assert qparams_precision == torch.float32 - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=torch.int8, granularity="per_token", is_symmetric=False, @@ -374,11 +379,11 @@ def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantize def _get_8da4w_weight_config( group_size: int, qparams_precision: torch.dtype, -) -> FakeQuantizeConfig: +) -> IntxFakeQuantizeConfig: """ - Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. + Return the weight `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`. """ - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=TorchAODType.INT4, group_size=group_size, is_symmetric=True, @@ -482,7 +487,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module): else: self._convert_qat_linear_4w(child) - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return _get_4w_weight_config(self.groupsize, self.scales_precision) @@ -553,11 +558,11 @@ def disable_4w_fake_quant(mod: torch.nn.Module): def _get_4w_weight_config( group_size: int, qparams_precision: torch.dtype, -) -> FakeQuantizeConfig: +) -> IntxFakeQuantizeConfig: """ - Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`. + Return the weight `IntxFakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`. 
""" - return FakeQuantizeConfig( + return IntxFakeQuantizeConfig( dtype=torch.uint4, group_size=group_size, is_symmetric=False, @@ -595,7 +600,7 @@ def __init__( weight_granularity = "per_group" else: weight_granularity = "per_channel" - self._weight_config = FakeQuantizeConfig( + self._weight_config = IntxFakeQuantizeConfig( dtype=torch.int4, granularity=weight_granularity, group_size=group_size, @@ -632,8 +637,8 @@ def convert( ) -> torch.nn.Module: raise NotImplementedError - def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet") - def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]: + def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]: return self.weight_config From cba74300cf9cf7267ebb7eb85e71c262370808e7 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 30 Jul 2025 15:58:09 -0700 Subject: [PATCH 2/4] Update on "[bc-breaking] Generalize FakeQuantizeConfig beyond intx" **Summary:** The existing `FakeQuantizeConfig` performs only intx quantization, but we plan to extend QAT to other dtypes such as fp8 and nvfp4 in the near future. This is the necessary refactor before that. Specifically: ``` # New abstract class FakeQuantizeConfigBase # Rename FakeQuantizeConfig -> IntxFakeQuantizeConfig ``` In the future, we will have other types of `FakeQuantizeConfigBase` for float dtypes that users can pass in instead of the existing Intx one. **BC-breaking notes:** For BC, we keep around the old names to reference the new ones. However, this commit is still BC-breaking in the sense that a few APIs now accept the abstract `FakeQuantizeConfigBase` instead. For the most part, this abstract class will be hidden from the user. Before: ``` activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=32) ``` After: ``` activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) ``` **Test Plan:** python test/quantization/test_qat.py [ghstack-poisoned] From 106e14627cef2e4a06fd0da3dbe0ffd2d2bf2afa Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 30 Jul 2025 16:00:25 -0700 Subject: [PATCH 3/4] Update on "[bc-breaking] Generalize FakeQuantizeConfig beyond intx" **Summary:** The existing `FakeQuantizeConfig` performs only intx quantization, but we plan to extend QAT to other dtypes such as fp8 and nvfp4 in the near future. This is the necessary refactor before that. Specifically: ``` # New abstract class FakeQuantizeConfigBase # Rename FakeQuantizeConfig -> IntxFakeQuantizeConfig ``` In the future, we will have other types of `FakeQuantizeConfigBase` for float dtypes that users can pass in instead of the existing Intx one. **BC-breaking notes:** For BC, we keep around the old names to reference the new ones. However, this commit is still BC-breaking in the sense that a few APIs now accept the abstract `FakeQuantizeConfigBase` instead. For the most part, this abstract class will be hidden from the user. 
Before: ``` activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=32) ``` After: ``` activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) ``` **Test Plan:** python test/quantization/test_qat.py [ghstack-poisoned] From 8245cee7b6cf19faf8ca547f5097823ec2c13ee3 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Thu, 31 Jul 2025 14:06:43 -0700 Subject: [PATCH 4/4] Update on "[bc-breaking] Generalize FakeQuantizeConfig beyond intx" **Summary:** The existing `FakeQuantizeConfig` performs only intx quantization, but we plan to extend QAT to other dtypes such as fp8 and nvfp4 in the near future. This is the necessary refactor before that. Specifically: ``` # New abstract class FakeQuantizeConfigBase # Rename FakeQuantizeConfig -> IntxFakeQuantizeConfig ``` In the future, we will have other types of `FakeQuantizeConfigBase` for float dtypes that users can pass in instead of the existing Intx one. **BC-breaking notes:** For BC, we keep around the old names to reference the new ones. However, this commit is still BC-breaking in the sense that a few APIs now accept the abstract `FakeQuantizeConfigBase` instead. For the most part, this abstract class will be hidden from the user. Before: ``` activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = FakeQuantizeConfig(torch.int4, group_size=32) ``` After: ``` activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False) weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32) ``` **Test Plan:** python test/quantization/test_qat.py [ghstack-poisoned] --- torchao/quantization/qat/fake_quantize_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py index c6ad0b39c3..7369c02148 100644 --- a/torchao/quantization/qat/fake_quantize_config.py +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -25,7 +25,6 @@ ) -@dataclass class FakeQuantizeConfigBase(abc.ABC): """ Base class for representing fake quantization config.
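For readers updating call sites, below is a minimal sketch of the prepare step under the new names introduced by this patch series. It is illustrative only: the toy model and its dimensions are hypothetical, and the sketch assumes just what the series adds, namely `IntxFakeQuantizeConfig` exported from `torchao.quantization.qat`, with `FakeQuantizeConfig` kept as a backward-compatibility alias for it.

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,  # BC alias, same object as IntxFakeQuantizeConfig
    IntxFakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

# Hypothetical toy model, used only to illustrate the prepare step.
model = torch.nn.Sequential(torch.nn.Linear(512, 256))

# New spelling: int8 per-token asymmetric activations, int4 per-group weights.
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)

# Prepare: swaps nn.Linear for FakeQuantizedLinear using the configs above.
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# Existing code that still imports the old name keeps working,
# since FakeQuantizeConfig now resolves to IntxFakeQuantizeConfig.
assert FakeQuantizeConfig is IntxFakeQuantizeConfig
```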