Generalize FakeQuantizer beyond intx #2714

Merged: 1 commit, merged on Aug 8, 2025

Changes from all commits

3 changes: 2 additions & 1 deletion docs/source/api_ref_qat.rst

@@ -28,7 +28,8 @@ Custom QAT APIs
     IntxFakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
-    FakeQuantizer
+    FakeQuantizerBase
+    IntxFakeQuantizer
     linear.enable_linear_fake_quant
     linear.disable_linear_fake_quant

20 changes: 11 additions & 9 deletions test/quantization/test_qat.py

@@ -46,7 +46,7 @@
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
-    FakeQuantizer,
+    IntxFakeQuantizer,
     _Float8RowwiseActivationFakeQuantizer,
 )
 from torchao.quantization.qat.linear import (
@@ -1466,10 +1466,10 @@ def test_fake_quantize_config_torch_intx(self):
     )
     def test_fake_quantizer_repr(self):
         """
-        Test that `repr(FakeQuantizer(config))` exposes useful config details.
+        Test that `repr(IntxFakeQuantizer(config))` exposes useful config details.
         """
         config = IntxFakeQuantizeConfig(torch.int4, group_size=128)
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         fake_quantizer_repr = repr(fake_quantizer)
         self.assertTrue("dtype=torch.int4" in fake_quantizer_repr)
         self.assertTrue("group_size=128" in fake_quantizer_repr)
@@ -1500,15 +1500,15 @@ def test_qat_linear_bias(self):
     def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
-        1. FakeQuantizer with asymmetric per_token config
+        1. IntxFakeQuantizer with asymmetric per_token config
         2. torchao.quantization.utils.per_token_dynamic_quant
         """
         from torchao.quantization.utils import per_token_dynamic_quant

         torch.manual_seed(self.SEED)
         x = torch.randn(1, 235, 2048).to(dtype)
         config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         fake_quantizer_out = fake_quantizer(x)
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)
@@ -1580,7 +1580,7 @@ def test_fake_quantize_config_eps(self):
             is_symmetric=False,
             eps=eps,
         )
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         actual_out = fake_quantizer(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)

@@ -1638,7 +1638,7 @@ def test_qat_8da4w_eps(self):
     )
     def test_fake_quantizer_range_learning(self):
         """
-        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        Test that range learning requires `IntxFakeQuantizer`s to be initialized correctly.
         """
         config = IntxFakeQuantizeConfig(
             torch.int8,
@@ -1648,7 +1648,7 @@ def test_fake_quantizer_range_learning(self):
             scale_precision=torch.float32,
             zero_point_precision=torch.float32,
         )
-        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer = IntxFakeQuantizer(config)
         example_inputs = (torch.randn(2, 3),)

         # Not initialized, should fail
@@ -1770,7 +1770,7 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertIsInstance(
             linear.activation_fake_quantizer, _Float8RowwiseActivationFakeQuantizer
         )
-        self.assertIsInstance(linear.weight_fake_quantizer, FakeQuantizer)
+        self.assertIsInstance(linear.weight_fake_quantizer, IntxFakeQuantizer)
         prev_weight = copy.deepcopy(m.linear1.weight)

         # Simulate training
@@ -1854,6 +1854,7 @@ def test_qat_api_deprecation(self):
         """
         from torchao.quantization.qat import (
             FakeQuantizeConfig,
+            FakeQuantizer,
             from_intx_quantization_aware_training,
             intx_quantization_aware_training,
         )
@@ -1868,6 +1869,7 @@
             intx_quantization_aware_training: (),
             from_intx_quantization_aware_training: (),
             FakeQuantizeConfig: (torch.int8, "per_channel"),
+            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),),
         }

         with warnings.catch_warnings(record=True) as _warnings:

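For readers skimming the test changes, here is a standalone sketch of the per-token parity check exercised above. It uses only names confirmed by this diff (IntxFakeQuantizer, IntxFakeQuantizeConfig, per_token_dynamic_quant); the tensor shape is illustrative.

import torch

from torchao.quantization.qat import IntxFakeQuantizeConfig, IntxFakeQuantizer
from torchao.quantization.utils import per_token_dynamic_quant

# Asymmetric per-token int8 fake quantization should match the eager
# per-token dynamic quant helper bit-for-bit, as the test asserts.
x = torch.randn(1, 235, 2048)
config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
fake_quantizer = IntxFakeQuantizer(config)
torch.testing.assert_close(
    fake_quantizer(x), per_token_dynamic_quant(x), atol=0, rtol=0
)
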
2 changes: 1 addition & 1 deletion torchao/quantization/prototype/qat/fake_quantizer.py

@@ -1,5 +1,5 @@
 from torchao.quantization.qat.fake_quantizer import (
-    FakeQuantizer,
+    IntxFakeQuantizer as FakeQuantizer,
 )

 __all__ = [

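A small illustration of what this shim implies (an inference from the `import ... as` line above, not something exercised in this PR): the prototype path now re-exports the renamed class itself, so both names refer to the same object.

from torchao.quantization.prototype.qat.fake_quantizer import FakeQuantizer
from torchao.quantization.qat import IntxFakeQuantizer

# The prototype name is a plain alias created by `IntxFakeQuantizer as FakeQuantizer`,
# unlike the deprecated subclass of the same name in torchao.quantization.qat.
assert FakeQuantizer is IntxFakeQuantizer
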
10 changes: 8 additions & 2 deletions torchao/quantization/qat/__init__.py

@@ -17,7 +17,11 @@
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
 )
-from .fake_quantizer import FakeQuantizer
+from .fake_quantizer import (
+    FakeQuantizer,
+    FakeQuantizerBase,
+    IntxFakeQuantizer,
+)
 from .linear import (
     FakeQuantizedLinear,
     Float8ActInt4WeightQATQuantizer,
@@ -29,8 +33,9 @@
     "QATConfig",
     "QATStep",
     "FakeQuantizeConfigBase",
+    "FakeQuantizerBase",
     "IntxFakeQuantizeConfig",
-    "FakeQuantizer",
+    "IntxFakeQuantizer",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
     # Prototype
@@ -42,6 +47,7 @@
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     # for BC
+    "FakeQuantizer",
     "FakeQuantizeConfig",
     "from_intx_quantization_aware_training",
     "FromIntXQuantizationAwareTrainingConfig",

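To make the new surface concrete, a short sketch of the import paths after this change; the behavior of the old name follows from the BC subclass and the deprecation test above.

import torch

from torchao.quantization.qat import (
    FakeQuantizer,          # deprecated, kept for BC
    FakeQuantizerBase,      # new base class with the from_config factory
    IntxFakeQuantizer,      # new name for the integer fake quantizer
    IntxFakeQuantizeConfig,
)

config = IntxFakeQuantizeConfig(torch.int8, "per_channel")
preferred = IntxFakeQuantizer(config)   # new spelling
legacy = FakeQuantizer(config)          # still works, logs a deprecation warning
assert isinstance(legacy, IntxFakeQuantizer)
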
6 changes: 3 additions & 3 deletions torchao/quantization/qat/api.py

@@ -382,14 +382,14 @@ def initialize_fake_quantizers(
 ) -> None:
     """
     (Prototype) Initialize the scales and zero points on all
-    :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    :class:`~torchao.quantization.qat.fake_quantizer.IntxFakeQuantizerBase`
     in the model based on the provided example inputs.
     """
     # avoid circular dependencies
-    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+    from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer

     def _set_initialized(m: torch.nn.Module):
-        if isinstance(m, FakeQuantizer):
+        if isinstance(m, IntxFakeQuantizer):
             m._initialized = True

     model.apply(_set_initialized)

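A hedged sketch of the initialization flow this docstring describes. The `initialize_fake_quantizers(model, example_inputs)` call signature and the exact range-learning config below are assumptions (the test above elides part of the config); they are not spelled out in full by this diff.

import torch

from torchao.quantization.qat import IntxFakeQuantizeConfig
from torchao.quantization.qat.api import initialize_fake_quantizers
from torchao.quantization.qat.linear import FakeQuantizedLinear

# Range-learning fake quantizers start uninitialized; forward fails until
# scales and zero points are seeded from example inputs.
config = IntxFakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
model = FakeQuantizedLinear(16, 32, weight_config=config)

example_inputs = (torch.randn(2, 16),)
initialize_fake_quantizers(model, example_inputs)  # sets _initialized on each IntxFakeQuantizer
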
4 changes: 2 additions & 2 deletions torchao/quantization/qat/embedding.py

@@ -17,7 +17,7 @@
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
 )
-from .fake_quantizer import FakeQuantizer
+from .fake_quantizer import FakeQuantizerBase
 from .utils import (
     _get_qmin_qmax,
 )
@@ -66,7 +66,7 @@ def __init__(
             **kwargs,
         )
         if weight_config is not None:
-            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+            self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
         else:
             self.weight_fake_quantizer = None

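A sketch of the embedding path after this change. `FakeQuantizedEmbedding` is assumed to follow `torch.nn.Embedding`'s constructor plus the `weight_config` keyword shown above; the sizes are illustrative (the group size is assumed to need to divide the embedding dimension).

import torch

from torchao.quantization.qat import FakeQuantizedEmbedding, IntxFakeQuantizeConfig

weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)
# The weight fake quantizer is now built via FakeQuantizerBase.from_config(...).
emb = FakeQuantizedEmbedding(1000, 64, weight_config=weight_config)
tokens = torch.randint(0, 1000, (2, 8))
out = emb(tokens)  # embedding lookup with fake-quantized weights
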
45 changes: 35 additions & 10 deletions torchao/quantization/qat/fake_quantizer.py

@@ -34,15 +34,37 @@
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _Float8RowwiseFakeQuantize,
+    _log_deprecation_warning,
 )


-class FakeQuantizer(torch.nn.Module):
+class FakeQuantizerBase(torch.nn.Module):
     """
     Generic module for applying fake quantization to a tensor, as specified in the config.
     """

-    def __init__(self, config: FakeQuantizeConfigBase):
+    config: FakeQuantizeConfigBase
+
+    def __repr__(self) -> str:
+        """
+        Return a human readable representation of this `FakeQuantizer` with config details.
+        """
+        return "FakeQuantizer(%s)" % self.config
+
+    @staticmethod
+    def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
+        if isinstance(config, IntxFakeQuantizeConfig):
+            return IntxFakeQuantizer(config)
+        else:
+            raise ValueError(f"Unknown config type: {config}")
+
+
+class IntxFakeQuantizer(FakeQuantizerBase):
+    """
+    Generic module for applying integer fake quantization to a tensor, as specified in the config.
+    """
+
+    def __init__(self, config: IntxFakeQuantizeConfig):
         super().__init__()
         self.config = config
         self.enabled = True
@@ -62,9 +84,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if not self.enabled:
             return x

-        if not isinstance(self.config, IntxFakeQuantizeConfig):
-            raise ValueError("Only IntxFakeQuantizeConfig is supported currently")
-
         if (
             self.config.range_learning
             and not self._initialized
@@ -186,13 +205,19 @@ def _maybe_update_qparams_for_range_learning(self) -> None:
         self.scale = torch.nn.Parameter(scale, requires_grad=True)
         self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)

-    def __repr__(self) -> str:
-        """
-        Return a human readable representation of this `FakeQuantizer` with config details.
-        """
-        return "FakeQuantizer(%s)" % self.config
+
+# For BC
+class FakeQuantizer(IntxFakeQuantizer):
+    """
+    (Deprecated) Please use :class:`~torchao.quantization.qat.IntxFakeQuantizer` instead.
+    """
+
+    def __init__(self, config: FakeQuantizeConfigBase):
+        super().__init__(config)
+        _log_deprecation_warning(self)


+# TODO: make this a FakeQuantizerBase
 class _Float8RowwiseActivationFakeQuantizer(torch.nn.Module):
     """
     Simple fake quantizer for float8 rowwise fake quantization, intended for activations only.

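The core of the PR is the `FakeQuantizerBase.from_config` factory and the `IntxFakeQuantizer` rename shown above. A minimal usage sketch, using only names defined in this diff (the tensor shape is illustrative):

import torch

from torchao.quantization.qat import FakeQuantizerBase, IntxFakeQuantizeConfig, IntxFakeQuantizer

config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)

# from_config dispatches on the config type: intx configs yield IntxFakeQuantizer,
# and unknown config types raise ValueError.
fq = FakeQuantizerBase.from_config(config)
assert isinstance(fq, IntxFakeQuantizer)

x = torch.randn(16, 32)
y = fq(x)   # fake-quantized tensor, same shape and dtype as x
print(fq)   # "FakeQuantizer(IntxFakeQuantizeConfig(...))" via the shared __repr__
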
8 changes: 5 additions & 3 deletions torchao/quantization/qat/linear.py

@@ -32,7 +32,7 @@
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
-    FakeQuantizer,
+    FakeQuantizerBase,
     _Float8RowwiseActivationFakeQuantizer,
 )
 from .utils import (
@@ -84,7 +84,9 @@ def __init__(
         )
         # initialize activation fake quantizer
         if activation_config is not None:
-            self.activation_fake_quantizer = FakeQuantizer(activation_config)
+            self.activation_fake_quantizer = FakeQuantizerBase.from_config(
+                activation_config
+            )
         else:
             self.activation_fake_quantizer = None

@@ -97,7 +99,7 @@ def __init__(
                     "in_features (%s) %% group_size (%s) must be == 0"
                     % (in_features, group_size)
                 )
-            self.weight_fake_quantizer = FakeQuantizer(weight_config)
+            self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
         else:
             self.weight_fake_quantizer = None

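Finally, a sketch of constructing `FakeQuantizedLinear` directly with the two configs this constructor now routes through `FakeQuantizerBase.from_config`. The positional `(in_features, out_features)` arguments and the `bias` keyword are assumed to mirror `torch.nn.Linear`; only `activation_config`/`weight_config` are confirmed by the diff.

import torch

from torchao.quantization.qat import FakeQuantizedLinear, IntxFakeQuantizeConfig

# int8 asymmetric per-token activations + int4 grouped weights, 8da4w-style QAT.
activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32)

linear = FakeQuantizedLinear(
    128,
    256,
    bias=False,
    activation_config=activation_config,
    weight_config=weight_config,
)
out = linear(torch.randn(2, 128))  # fake-quantized activations and weights in the matmul
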