@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)


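argparse maps the new flags onto `per_tensor_scale_a` / `per_tensor_scale_b` attributes of the parsed namespace (dashes become underscores). A minimal sketch of how the flags could be exercised, assuming `parse_args` is imported from this operator module and no other flag is required:

```python
# Hypothetical invocation; sizes and scale values are illustrative only.
args = parse_args(
    ["--m", "1024", "--k", "512", "--n", "2048",
     "--per-tensor-scale-a", "0.5", "--per-tensor-scale-b", "0.25"]
)
assert args.per_tensor_scale_a == 0.5
assert args.per_tensor_scale_b == 0.25  # both default to None when the flags are omitted
```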
@@ -54,18 +56,53 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)

+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(
+            x: torch.Tensor, custom_scale: float = None
+        ) -> torch.Tensor:
+            # For tensor-wise scaling, kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
+        def _get_scale_per_row(
+            x: torch.Tensor, transpose: bool = False
+        ) -> torch.Tensor:
+            if transpose:  # scale_b.shape should be [1, N]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=0, keepdim=True).values
+                )
+            else:  # scale_a.shape should be [M, 1]
+                scale = (
+                    torch.finfo(torch.float8_e4m3fn).max
+                    / x.abs().max(dim=1, keepdim=True).values
+                )
+            return scale.to(
+                torch.float32
+            )  # For row-wise scaling, kernel requires a float32 scale tensor
+
         def args(m, n, k):
             a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = torch.randn(k, n, device=self.device).to(torch.float16).T.contiguous().T

             if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+                scale_a = _get_scale_per_row(a)
+                scale_b = _get_scale_per_row(b, transpose=True)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(
+                    a, custom_scale=self.extra_args.per_tensor_scale_a
+                )
+                scale_b = _get_scale_per_tensor(
+                    b, custom_scale=self.extra_args.per_tensor_scale_b
+                )

             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
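For context on the helpers added above: the per-tensor path returns one float32 scalar per operand, while the row-wise path returns a `[M, 1]` scale for `a` and a `[1, N]` scale for `b`. A self-contained sketch of the same expressions (shapes are illustrative; the nested helpers are not importable, so they are restated inline):

```python
# Standalone sketch mirroring _get_scale_per_tensor / _get_scale_per_row above.
import torch

m, k, n = 8, 16, 4
a = torch.randn(m, k, dtype=torch.float16)
b = torch.randn(k, n, dtype=torch.float16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for float8_e4m3fn

# Per-tensor: a single float32 scalar per operand.
scale_a = (fp8_max / a.abs().max()).to(torch.float32)  # shape []
# Row-wise: [M, 1] for a (max over each row), [1, N] for b (max over each column).
scale_a_row = (fp8_max / a.abs().max(dim=1, keepdim=True).values).to(torch.float32)  # [8, 1]
scale_b_row = (fp8_max / b.abs().max(dim=0, keepdim=True).values).to(torch.float32)  # [1, 4]
```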
@@ -103,16 +140,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)

-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
         )

     @register_benchmark()
@@ -129,7 +160,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
                 scale_a,
                 scale_b,
                 use_fast_accum=True,
-                out_dtype=self._get_out_dtype(),
+                out_dtype=self._get_dtype(),
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
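Both `torch_fp8_gemm` and `pt2_fp8_gemm` bottom out in the same `torch._scaled_mm` call, so here is a compact standalone sketch of that call pattern. It is a sketch only, assuming FP8-capable CUDA hardware (e.g., H100); the `.T.contiguous().T` step mirrors the input iterator above, which lays `b` out column-major as the scaled-mm kernel expects.

```python
# Minimal eager sketch of the call benchmarked above (illustrative sizes; assumes
# a CUDA device with float8 support).
import torch

m, k, n = 64, 128, 64
a = torch.randn(m, k, device="cuda", dtype=torch.float16)
b = torch.randn(k, n, device="cuda", dtype=torch.float16).T.contiguous().T  # column-major

fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = (fp8_max / a.abs().max()).to(torch.float32)  # per-tensor scales, as in get_input_iter
scale_b = (fp8_max / b.abs().max()).to(torch.float32)

out = torch._scaled_mm(
    a.to(torch.float8_e4m3fn),
    b.to(torch.float8_e4m3fn),
    scale_a,
    scale_b,
    use_fast_accum=True,
    out_dtype=torch.float16,  # _get_dtype() picks bfloat16 when row-wise scaling is enabled
)
```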