3434 FakeQuantizeConfigBase ,
3535 Float8FakeQuantizeConfig ,
3636 IntxFakeQuantizeConfig ,
37+ NVFP4FakeQuantizeConfig ,
3738)
3839from .utils import (
3940 _fake_quantize_per_channel_group ,
@@ -59,8 +60,10 @@ def __repr__(self) -> str:
5960 def from_config (config : FakeQuantizeConfigBase ) -> "FakeQuantizerBase" :
6061 if isinstance (config , IntxFakeQuantizeConfig ):
6162 return IntxFakeQuantizer (config )
62- if isinstance (config , Float8FakeQuantizeConfig ):
63+ elif isinstance (config , Float8FakeQuantizeConfig ):
6364 return Float8FakeQuantizer (config )
65+ elif isinstance (config , NVFP4FakeQuantizeConfig ):
66+ return NVFP4FakeQuantizer (config )
6467 else :
6568 raise ValueError (f"Unknown config type: { config } " )
6669
@@ -73,6 +76,7 @@ class Float8FakeQuantizer(FakeQuantizerBase):
7376 def __init__ (self , config : Float8FakeQuantizeConfig ):
7477 super ().__init__ ()
7578 self .config = config
79+ torch ._C ._log_api_usage_once ("torchao.quantization.qat.Float8FakeQuantizer" )
7680
7781 def forward (self , x : torch .Tensor ) -> torch .Tensor :
7882 original_dtype = x .dtype
@@ -91,14 +95,60 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
9195 return dq
9296
9397
class NVFP4FakeQuantizer(FakeQuantizerBase):
    """
    Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
    """

    def __init__(self, config: NVFP4FakeQuantizeConfig):
        super().__init__()
        # One-time telemetry hook so torch can track feature adoption.
        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantize *x* to NVFP4 and immediately dequantize, returning a
        tensor of the same shape and dtype as the input."""
        # Imported lazily so the prototype package is only required when
        # this quantizer is actually used.
        from torchao.prototype.mx_formats.nvfp4_tensor import (
            _nvfp4_quantize,
            per_tensor_amax_to_scale,
        )

        group_size = 16  # NVFP4 scales fixed 16-element blocks
        input_shape = x.shape
        # Collapse a 3D input (e.g. batch x seq x dim) to 2D for blockwise
        # quantization; the original shape is restored on return.
        # NOTE(review): assumes x is 2D or 3D and that the last dim is a
        # multiple of 16 — confirm with callers.
        if x.dim() == 3:
            x = x.view(-1, x.shape[-1])

        per_tensor_scale = None
        if self.config.use_per_tensor_scale:
            per_tensor_scale = per_tensor_amax_to_scale(torch.abs(x).max())

        # Quantize. skip_dtype_cast_and_packing=True keeps the quantized
        # values in the input dtype (simulated quantization, no bit packing).
        scale, q = _nvfp4_quantize(
            x,
            block_size=group_size,
            per_tensor_scale=per_tensor_scale,
            skip_dtype_cast_and_packing=True,
        )
        if self.config.use_per_tensor_scale:
            # Fold the global scale back into the per-block scales.
            scale = scale * per_tensor_scale
        assert q.dtype == x.dtype
        assert scale.dtype == torch.float32

        # Dequantize: broadcast each block's scale over its 16 values.
        rows, cols = q.shape
        q = q.view(rows, cols // group_size, group_size)
        scale = scale.view(rows, cols // group_size, 1)
        dq = q * scale
        return dq.view(input_shape).to(x.dtype)
143+
94144class IntxFakeQuantizer (FakeQuantizerBase ):
95145 """
96146 Generic module for applying integer fake quantization to a tensor, as specified in the config.
97147 """
98148
99149 def __init__ (self , config : IntxFakeQuantizeConfig ):
100150 super ().__init__ ()
101- torch ._C ._log_api_usage_once ("torchao.quantization.qat.FakeQuantizer " )
151+ torch ._C ._log_api_usage_once ("torchao.quantization.qat.IntxFakeQuantizer " )
102152 self .config = config
103153 self .enabled = True
104154 self .scale : Optional [torch .Tensor ] = None
0 commit comments