@@ -749,13 +749,37 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
+    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
+
+
+class _Float8Round(torch.autograd.Function):
+    """
+    Cast a tensor to float8 and back to float32 with backward STE.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+        return x.to(torch.float8_e4m3fn).to(torch.float32)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy
+
+
+def _nvfp4_quantize(
+    data_hp: torch.Tensor,
+    block_size: int = 16,
+    per_tensor_scale: Optional[torch.Tensor] = None,
+    skip_dtype_cast_and_packing: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
     assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"
 
+    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
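
A minimal sketch (not part of the diff) of what the new _Float8Round function does: the forward pass snaps values onto the float8_e4m3fn grid, while the backward pass acts as a straight-through estimator and passes gradients through unchanged. The class body is copied from the hunk above; the input values are arbitrary.

import torch

class _Float8Round(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Round-trip through float8_e4m3fn: values land on the fp8 grid.
        return x.to(torch.float8_e4m3fn).to(torch.float32)

    @staticmethod
    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
        # Straight-through estimator: the rounding is treated as identity in backward.
        return gy

x = torch.tensor([0.1234, 3.3, 300.0], requires_grad=True)
y = _Float8Round.apply(x)
y.sum().backward()
print(y)       # rounded values, e.g. tensor([0.1250, 3.2500, 288.0000], grad_fn=...)
print(x.grad)  # tensor([1., 1., 1.]) -- gradients pass straight through
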
@@ -767,10 +791,8 @@ def nvfp4_quantize(
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
-            torch.float8_e4m3fn
-        )
-        block_scale_fp32 = block_scale_fp8.to(torch.float32)
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
+        block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
     else:
@@ -782,8 +804,8 @@ def nvfp4_quantize(
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        ).to(torch.float8_e4m3fn)
-        scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
+        )
+        scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
         total_scale = per_tensor_scale * scaled_block_scales_fp32
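
A hedged numeric sketch of the two-level scaling math in this hunk. E4M3_EPS and F8E4M3_MAX are module constants not shown in the diff; the values used below are assumptions for illustration, and the explicit fp8 round-trip stands in for _Float8Round.apply.

import torch

F8E4M3_MAX = 448.0      # assumed value of the module constant (float8_e4m3fn max)
E4M3_EPS = 2.0 ** -6    # assumed stand-in for the module constant

per_tensor_scale = torch.tensor(2.5)
block_scale_fp32 = torch.tensor([1.0, 7.5, 1200.0])  # example per-block scales

# Re-scale the block scales by the per-tensor scale, clamp into e4m3 range,
# then round onto the float8_e4m3fn grid (what _Float8Round.apply does above).
scaled = torch.clamp(block_scale_fp32 / per_tensor_scale, min=E4M3_EPS, max=F8E4M3_MAX)
scaled_rounded = scaled.to(torch.float8_e4m3fn).to(torch.float32)

# The scale actually applied to the data is the product of the two levels.
total_scale = per_tensor_scale * scaled_rounded
print(scaled_rounded)  # e.g. tensor([  0.4062,   3.0000, 448.0000])
print(total_scale)     # e.g. tensor([   1.0156,    7.5000, 1120.0000])
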
@@ -792,8 +814,11 @@ def nvfp4_quantize(
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    data_lp = f32_to_f4_unpacked(data_scaled)
-    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-    data_lp = pack_uint4(data_lp)
-    return out_scales, data_lp
+    if skip_dtype_cast_and_packing:
+        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
+    else:
+        data_lp = f32_to_f4_unpacked(data_scaled)
+        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+        data_lp = pack_uint4(data_lp)
+        return out_scales.to(torch.float8_e4m3fn), data_lp
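
To make the end-to-end effect concrete, here is a hedged sketch of the single-level path followed by dequantization, mirroring the skip_dtype_cast_and_packing=True branch (no FP4 cast, no uint4 packing). The constants and the amax-based block_scale recipe are assumptions for illustration; the diff itself does not show how block_scale is computed.

import torch

F4_E2M1_MAX = 6.0       # assumed value: largest magnitude representable in FP4 E2M1
F8E4M3_MAX = 448.0      # assumed value of the module constant
E4M3_EPS = 2.0 ** -6    # assumed stand-in for the module constant
block_size = 16

data_hp = torch.randn(4, 64, dtype=torch.bfloat16)
orig_dtype, orig_shape = data_hp.dtype, data_hp.shape
data_f32 = data_hp.float().reshape(orig_shape[0], -1, block_size)

# Assumed block-scale recipe: per-block amax divided by the FP4 maximum.
block_scale = data_f32.abs().amax(dim=-1) / F4_E2M1_MAX

# Single-level path from the diff: clamp, round onto the fp8 grid, divide.
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
block_scale_fp32 = block_scale_fp8.to(torch.float8_e4m3fn).to(torch.float32)
data_scaled = data_f32 / block_scale_fp32.unsqueeze(-1)
data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX).view(orig_shape)

# With skip_dtype_cast_and_packing=True the function returns the scales as float32
# and data_scaled cast back to orig_dtype, skipping f32_to_f4_unpacked/pack_uint4.
out_scales, data_fq = block_scale_fp8.to(torch.float32), data_scaled.to(orig_dtype)

# Dequantizing with the rounded block scales approximately recovers the input;
# the remaining error comes from fp8 scale rounding and clamping (no FP4 rounding yet).
recon = data_scaled.view(orig_shape[0], -1, block_size) * block_scale_fp32.unsqueeze(-1)
print((recon.view(orig_shape) - data_hp.float()).abs().max())
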