2929from .fake_quantize_config import (
3030 FakeQuantizeConfigBase ,
3131 IntxFakeQuantizeConfig ,
32+ NVFP4FakeQuantizeConfig ,
3233)
3334from .utils import (
3435 _fake_quantize_per_channel_group ,
@@ -46,13 +47,14 @@ def __init__(self, config: FakeQuantizeConfigBase):
4647 super ().__init__ ()
4748 self .config = config
4849 self .enabled = True
49- self .scale : Optional [torch .Tensor ] = None
50- self .zero_point : Optional [torch .Tensor ] = None
5150
52- # For range learning only
53- # TODO: make this configurable?
54- self ._scale_eps = 1e-9
55- self ._initialized = False
51+ if isinstance (self .config , IntxFakeQuantizeConfig ):
52+ self .scale : Optional [torch .Tensor ] = None
53+ self .zero_point : Optional [torch .Tensor ] = None
54+ # For range learning only
55+ # TODO: make this configurable?
56+ self ._scale_eps = 1e-9
57+ self ._initialized = False
5658
5759 def forward (self , x : torch .Tensor ) -> torch .Tensor :
5860 """
@@ -62,9 +64,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
6264 if not self .enabled :
6365 return x
6466
65- if not isinstance (self .config , IntxFakeQuantizeConfig ):
66- raise ValueError ("Only IntxFakeQuantizeConfig is supported currently" )
67+ if isinstance (self .config , NVFP4FakeQuantizeConfig ):
68+ return self ._nvfp4_forward (x )
69+ elif isinstance (self .config , IntxFakeQuantizeConfig ):
70+ return self ._intx_forward (x )
71+ else :
72+ raise ValueError (f"Unexpected config type { self .config } " )
73+
74+ def _nvfp4_forward (self , x : torch .Tensor ):
75+ """
76+ Apply NVFP4 fake quantization to the tensor following `NVFP4Tensor`.
77+ """
78+ from torchao .prototype .mx_formats .nvfp4_tensor import (
79+ _nvfp4_quantize ,
80+ per_tensor_amax_to_scale ,
81+ )
6782
83+ block_size = 16
84+ if self .config .use_per_tensor_scale :
85+ tensor_amax = torch .max (torch .abs (x ))
86+ per_tensor_scale = per_tensor_amax_to_scale (tensor_amax )
87+ else :
88+ per_tensor_scale = None
89+ scale , q = _nvfp4_quantize (
90+ x ,
91+ block_size = block_size ,
92+ per_tensor_scale = per_tensor_scale ,
93+ skip_dtype_cast_and_packing = True ,
94+ )
95+ assert q .dtype == x .dtype
96+ assert scale .dtype == torch .float32
97+ M , K = q .shape [0 ], q .shape [1 ]
98+ q = q .view (M , K // block_size , block_size )
99+ scale = scale .view (M , K // block_size , 1 )
100+ dq = q * scale
101+ return dq .view (x .shape )
102+
103+ def _intx_forward (self , x : torch .Tensor ) -> torch .Tensor :
104+ """
105+ Apply intx fake quantization to the tensor.
106+ """
68107 if (
69108 self .config .range_learning
70109 and not self ._initialized
@@ -77,15 +116,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
77116 )
78117
79118 if isinstance (self .config .granularity , PerToken ):
80- return self ._per_token_forward (x )
119+ return self ._intx_per_token_forward (x )
81120 elif isinstance (self .config .granularity , (PerAxis , PerGroup )):
82- return self ._per_channel_or_group_forward (x )
121+ return self ._intx_per_channel_or_group_forward (x )
83122 else :
84123 raise ValueError ("Unknown granularity '%s'" % self .config .granularity )
85124
86- def _per_token_forward (self , x : torch .Tensor ) -> torch .Tensor :
125+ def _intx_per_token_forward (self , x : torch .Tensor ) -> torch .Tensor :
87126 """
88- Perform per token fake quantization on the tensor.
127+ Perform intx per token fake quantization on the tensor.
89128 """
90129 if self .config .is_symmetric :
91130 raise NotImplementedError ("Symmetric per token is not supported yet" )
@@ -105,9 +144,9 @@ def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
105144 self ._maybe_update_qparams_for_range_learning ()
106145 return _fake_quantize_per_token (x , self .scale , self .zero_point , qmin , qmax )
107146
108- def _per_channel_or_group_forward (self , x : torch .Tensor ) -> torch .Tensor :
147+ def _intx_per_channel_or_group_forward (self , x : torch .Tensor ) -> torch .Tensor :
109148 """
110- Perform per channel or per group fake quantization on the tensor.
149+ Perform intx per channel or per group fake quantization on the tensor.
111150 We express per channel using per group where the group size is the size
112151 of the last dimension of the tensor.
113152 """
0 commit comments