Skip to content

Commit 209760e

Browse files
committed
Add NVFP4 QAT
**Summary:** This commit adds a QAT flow for NVFP4, following the numerics in `NVFP4Tensor` closely but without the dtyping casting, swizzling, and the packing/unpacking. Users can call this flow as follows: ``` from torchao.quantization import quantize_ from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig qat_config = QATConfig( weight_config=NVFP4FakeQuantizeConfig(), step="prepare", ) quantize_(model, qat_config) ``` **Test Plan:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 ``` Initial benchmarks on fine-tuning Qwen3-1.7B on alpaca for 3 epochs: ``` # Without QAT | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr| |--------|------:|------|------|---------------|---|------:|---|------| |wikitext| 2|none |None |bits_per_byte |↓ | 0.8322|± | N/A| | | |none |None |byte_perplexity|↓ | 1.7804|± | N/A| | | |none |None |word_perplexity|↓ |21.8611|± | N/A| # With QAT | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr| |--------|------:|------|------|---------------|---|------:|---|------| |wikitext| 2|none |None |bits_per_byte |↓ | 0.8271|± | N/A| | | |none |None |byte_perplexity|↓ | 1.7741|± | N/A| | | |none |None |word_perplexity|↓ |21.4467|± | N/A| ``` ghstack-source-id: fb5c617 Pull Request resolved: #2666
1 parent bc2c83e commit 209760e

File tree

7 files changed

+194
-18
lines changed

7 files changed

+194
-18
lines changed

test/quantization/test_qat.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def __init__(self):
118118
self.sub = Sub()
119119
self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
120120

121-
def example_inputs(self):
122-
return (torch.randn(1, 512).to(torch.float),)
121+
def example_inputs(self, device: torch.device = None):
122+
return (torch.randn((1, 512), device=device).to(torch.float),)
123123

124124
def _get_all_weight_scales(self) -> List[torch.Tensor]:
125125
return [
@@ -1928,7 +1928,7 @@ def test_quantize_api_fp8_int4(self):
19281928
"""
19291929
self._test_quantize_api_against_ptq(
19301930
Float8DynamicActivationInt4WeightConfig(),
1931-
target_prepare_sqnr=15,
1931+
target_prepare_sqnr=12,
19321932
target_convert_sqnr=float("inf"),
19331933
)
19341934

@@ -1952,6 +1952,47 @@ def test_infer_fp8_int4_config(self):
19521952
self.assertEqual(weight_config.group_size, 128)
19531953
self.assertTrue(weight_config.is_symmetric)
19541954

1955+
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
1956+
def test_quantize_api_nvfp4(self):
1957+
"""
1958+
Test the following:
1959+
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))
1960+
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert"))
1961+
"""
1962+
from torchao.prototype.mx_formats import NVFP4InferenceConfig
1963+
1964+
self._test_quantize_api_against_ptq(
1965+
NVFP4InferenceConfig(),
1966+
target_prepare_sqnr=8,
1967+
target_convert_sqnr=float("inf"),
1968+
)
1969+
1970+
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
1971+
@parametrize("use_per_tensor_scale", [True, False])
1972+
def test_qat_nvfp4(self, use_per_tensor_scale: bool):
1973+
"""
1974+
Test QAT with `NVFP4FakeQuantizeConfig`.
1975+
"""
1976+
from torchao.prototype.qat import NVFP4FakeQuantizeConfig
1977+
1978+
torch.manual_seed(self.SEED)
1979+
m = M().cuda()
1980+
baseline_model = copy.deepcopy(m)
1981+
qat_config = QATConfig(
1982+
activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
1983+
weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
1984+
step="prepare",
1985+
)
1986+
quantize_(m, qat_config)
1987+
1988+
# Compare prepared values
1989+
torch.manual_seed(self.SEED)
1990+
x = m.example_inputs("cuda")
1991+
out = m(*x)
1992+
baseline_out = baseline_model(*x)
1993+
sqnr = compute_error(out, baseline_out).item()
1994+
self.assertGreater(sqnr, 24)
1995+
19551996

19561997
instantiate_parametrized_tests(TestQAT)
19571998

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -751,13 +751,37 @@ def nvfp4_quantize(
751751
AssertionError: If input dtype is not supported, tensor size is not
752752
divisible by block_size, tensor is not contiguous, or block_size != 16
753753
"""
754+
return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
755+
756+
757+
class _Float8Round(torch.autograd.Function):
758+
"""
759+
Cast a tensor to float8 and back to float32 with backward STE.
760+
"""
761+
762+
@staticmethod
763+
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
764+
return x.to(torch.float8_e4m3fn).to(torch.float32)
765+
766+
@staticmethod
767+
def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
768+
return gy
769+
770+
771+
def _nvfp4_quantize(
772+
data_hp: torch.Tensor,
773+
block_size: int = 16,
774+
per_tensor_scale: Optional[torch.Tensor] = None,
775+
skip_dtype_cast_and_packing: bool = False,
776+
) -> tuple[torch.Tensor, torch.Tensor]:
754777
assert data_hp.dtype in (torch.bfloat16, torch.float), (
755778
f"{data_hp.dtype} not supported"
756779
)
757780
assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
758781
assert data_hp.is_contiguous(), "Only support contiguous data for now"
759782
assert block_size == 16, "NVFP4 requires block_size=16"
760783

784+
orig_dtype = data_hp.dtype
761785
orig_shape = data_hp.shape
762786
# Convert to float32 early for consistent precision with Triton implementation
763787
data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -769,10 +793,8 @@ def nvfp4_quantize(
769793
out_scales = None
770794
if per_tensor_scale is None:
771795
# We are doing single level scaling
772-
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
773-
torch.float8_e4m3fn
774-
)
775-
block_scale_fp32 = block_scale_fp8.to(torch.float32)
796+
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
797+
block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
776798
data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
777799
out_scales = block_scale_fp8
778800
else:
@@ -784,8 +806,8 @@ def nvfp4_quantize(
784806
scaled_block_scales = block_scale_fp32 / per_tensor_scale
785807
scaled_block_scales_fp8 = torch.clamp(
786808
scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
787-
).to(torch.float8_e4m3fn)
788-
scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
809+
)
810+
scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
789811
# We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
790812
# To apply to data
791813
total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -794,8 +816,11 @@ def nvfp4_quantize(
794816

795817
data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
796818
data_scaled = data_scaled.view(orig_shape)
797-
data_lp = f32_to_f4_unpacked(data_scaled)
798-
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
799-
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
800-
data_lp = pack_uint4(data_lp)
801-
return out_scales, data_lp
819+
if skip_dtype_cast_and_packing:
820+
return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
821+
else:
822+
data_lp = f32_to_f4_unpacked(data_scaled)
823+
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
824+
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
825+
data_lp = pack_uint4(data_lp)
826+
return out_scales.to(torch.float8_e4m3fn), data_lp

torchao/prototype/qat/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Temporary location for prototype QAT features that will
2+
# eventually live in torchao/quantization/qat
3+
4+
from .nvfp4 import (
5+
NVFP4FakeQuantizeConfig,
6+
NVFP4FakeQuantizer,
7+
)
8+
9+
__all__ = [
10+
"NVFP4FakeQuantizeConfig",
11+
"NVFP4FakeQuantizer",
12+
]

torchao/prototype/qat/nvfp4.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from dataclasses import dataclass
2+
3+
import torch
4+
5+
from torchao.prototype.mx_formats.nvfp4_tensor import (
6+
_nvfp4_quantize,
7+
per_tensor_amax_to_scale,
8+
)
9+
from torchao.quantization.qat import (
10+
FakeQuantizeConfigBase,
11+
FakeQuantizerBase,
12+
)
13+
14+
15+
@dataclass
16+
class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
17+
"""
18+
Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
19+
according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
20+
21+
Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.
22+
23+
Args:
24+
use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
25+
after the initial fp8 (e4m3) block-wise scaling (default True)
26+
"""
27+
28+
use_per_tensor_scale: bool = True
29+
30+
31+
class NVFP4FakeQuantizer(FakeQuantizerBase):
32+
"""
33+
(Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
34+
"""
35+
36+
def __init__(self, config: NVFP4FakeQuantizeConfig):
37+
super().__init__()
38+
torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
39+
self.config = config
40+
41+
def forward(self, x: torch.Tensor) -> torch.Tensor:
42+
block_size = 16
43+
original_shape = x.shape
44+
if x.dim() == 3:
45+
x = x.view(-1, x.shape[-1])
46+
if self.config.use_per_tensor_scale:
47+
tensor_amax = torch.max(torch.abs(x))
48+
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
49+
else:
50+
per_tensor_scale = None
51+
52+
# quantize
53+
scale, q = _nvfp4_quantize(
54+
x,
55+
block_size=block_size,
56+
per_tensor_scale=per_tensor_scale,
57+
skip_dtype_cast_and_packing=True,
58+
)
59+
if self.config.use_per_tensor_scale:
60+
scale = scale * per_tensor_scale
61+
assert q.dtype == x.dtype
62+
assert scale.dtype == torch.float32
63+
64+
# dequantize
65+
M, K = q.shape[0], q.shape[1]
66+
q = q.view(M, K // block_size, block_size)
67+
scale = scale.view(M, K // block_size, 1)
68+
dq = q * scale
69+
return dq.view(original_shape).to(x.dtype)

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,6 @@ def __post_init__(self):
320320
_log_deprecation_warning(self)
321321

322322

323-
# TODO: rewrite using registration API?
324323
def _infer_fake_quantize_configs(
325324
base_config: AOBaseConfig,
326325
) -> Tuple[Optional[FakeQuantizeConfigBase], Optional[FakeQuantizeConfigBase]]:
@@ -331,7 +330,15 @@ def _infer_fake_quantize_configs(
331330
332331
Return a 2-tuple of (activation_config, weight_config) for fake quantization.
333332
"""
333+
# TODO: rewrite using registration API so we don't need to import here
334334
# avoid circular imports
335+
from torchao.prototype.mx_formats import (
336+
NVFP4InferenceConfig,
337+
NVFP4MMConfig,
338+
)
339+
from torchao.prototype.qat import (
340+
NVFP4FakeQuantizeConfig,
341+
)
335342
from torchao.quantization import (
336343
Float8DynamicActivationFloat8WeightConfig,
337344
Float8DynamicActivationInt4WeightConfig,
@@ -385,6 +392,17 @@ def _infer_fake_quantize_configs(
385392
group_size=128,
386393
is_symmetric=True,
387394
)
395+
elif isinstance(base_config, NVFP4InferenceConfig):
396+
# Note: today the PTQ config does not allow the user to specify
397+
# `per_tensor_scales` due to serialization concerns. In the future
398+
# we may add a way to compute these dynamically (for activations),
399+
# but for now QAT will mimic the existing behavior of not having
400+
# `per_tensor_scales` (subject to change)
401+
if NVFP4MMConfig.DYNAMIC:
402+
act_config = NVFP4FakeQuantizeConfig(False)
403+
else:
404+
act_config = None
405+
weight_config = NVFP4FakeQuantizeConfig(False)
388406
else:
389407
raise ValueError("Unexpected base config: %s" % base_config)
390408
return (act_config, weight_config)

torchao/quantization/qat/fake_quantizer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,18 @@ def __repr__(self) -> str:
5757

5858
@staticmethod
5959
def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
60+
# TODO: rewrite using registration API so we don't need to import here
61+
from torchao.prototype.qat import (
62+
NVFP4FakeQuantizeConfig,
63+
NVFP4FakeQuantizer,
64+
)
65+
6066
if isinstance(config, IntxFakeQuantizeConfig):
6167
return IntxFakeQuantizer(config)
62-
if isinstance(config, Float8FakeQuantizeConfig):
68+
elif isinstance(config, Float8FakeQuantizeConfig):
6369
return Float8FakeQuantizer(config)
70+
elif isinstance(config, NVFP4FakeQuantizeConfig):
71+
return NVFP4FakeQuantizer(config)
6472
else:
6573
raise ValueError(f"Unknown config type: {config}")
6674

@@ -73,6 +81,7 @@ class Float8FakeQuantizer(FakeQuantizerBase):
7381
def __init__(self, config: Float8FakeQuantizeConfig):
7482
super().__init__()
7583
self.config = config
84+
torch._C._log_api_usage_once("torchao.quantization.qat.Float8FakeQuantizer")
7685

7786
def forward(self, x: torch.Tensor) -> torch.Tensor:
7887
original_dtype = x.dtype
@@ -98,7 +107,7 @@ class IntxFakeQuantizer(FakeQuantizerBase):
98107

99108
def __init__(self, config: IntxFakeQuantizeConfig):
100109
super().__init__()
101-
torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
110+
torch._C._log_api_usage_once("torchao.quantization.qat.IntxFakeQuantizer")
102111
self.config = config
103112
self.enabled = True
104113
self.scale: Optional[torch.Tensor] = None

torchao/quantization/qat/linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def __init__(
9292

9393
# initialize weight fake quantizer
9494
if weight_config is not None:
95-
if isinstance(weight_config.granularity, PerGroup):
95+
if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance(
96+
weight_config.granularity, PerGroup
97+
):
9698
group_size = weight_config.group_size
9799
if group_size is not None and in_features % group_size != 0:
98100
raise ValueError(

0 commit comments

Comments
 (0)