@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)
 
 
@@ -54,18 +56,53 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
+            # For tensor-wise scaling, kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = torch.randn(k, n, device=self.device).to(torch.float16).T.contiguous().T
 
             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -103,16 +140,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)
 
-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
         )
 
     @register_benchmark()
@@ -129,7 +160,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
                 scale_a,
                 scale_b,
                 use_fast_accum=True,
-                out_dtype=self._get_out_dtype(),
+                out_dtype=self._get_dtype(),
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
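# --- Standalone usage sketch (not part of the patch above) ---
# A minimal illustration of how the per-tensor scales produced by
# _get_scale_per_tensor feed into torch._scaled_mm, the baseline kernel this
# operator benchmarks. Assumptions: a PyTorch build whose torch._scaled_mm
# takes scale_a/scale_b positionally (as the diff itself does) and a CUDA GPU
# with FP8 e4m3 support; the shapes below are illustrative, not benchmark sizes.
import torch

m, k, n = 256, 512, 128
a = torch.randn(m, k, device="cuda", dtype=torch.float16)
# .t().contiguous().t() keeps b column-major, matching the layout the diff prepares
b = torch.randn(k, n, device="cuda", dtype=torch.float16).t().contiguous().t()

# Per-tensor scale: fp8 e4m3 max over the tensor's max magnitude, mirroring the
# default branch of _get_scale_per_tensor (float32, as the kernel requires)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = (fp8_max / a.abs().max()).to(torch.float32)
scale_b = (fp8_max / b.abs().max()).to(torch.float32)

# Kernels expect dtype=float8_e4m3fn
a_fp8 = a.to(torch.float8_e4m3fn)
b_fp8 = b.to(torch.float8_e4m3fn)

out = torch._scaled_mm(
    a_fp8, b_fp8, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
)
print(out.shape)  # torch.Size([256, 128])

# The --scaling_rowwise path instead passes [M, 1] and [1, N] float32 scales from
# _get_scale_per_row and, per _get_dtype, uses out_dtype=torch.bfloat16.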