@@ -41,6 +41,8 @@ def parse_args(args):
     parser.add_argument("--m", type=int)
     parser.add_argument("--k", type=int)
     parser.add_argument("--n", type=int)
+    parser.add_argument("--per-tensor-scale-a", type=float, default=None)
+    parser.add_argument("--per-tensor-scale-b", type=float, default=None)
     return parser.parse_args(args)
 
 
@@ -54,12 +56,25 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.extra_args = parse_args(extra_args)
 
+    def _get_dtype(self):
+        if self.extra_args.scaling_rowwise:
+            return torch.bfloat16
+        else:
+            return torch.float16
+
     def get_input_iter(self):
+        def _get_scale_per_tensor(x: torch.Tensor, custom_scale: float = None) -> torch.Tensor:
+            # For tensor-wise scaling, the kernel requires a float32 scale tensor
+            if custom_scale:
+                return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
+            scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
+            return scale.to(torch.float32)
+
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float16)
+            a = torch.randn(m, k, device=self.device).to(self._get_dtype())
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float16)
+                .to(self._get_dtype())
                 .T.contiguous()
                 .T
             )
@@ -69,8 +84,8 @@ def args(m, n, k):
                 scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
                 scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
             else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=a.device)
+                scale_a = _get_scale_per_tensor(a, custom_scale=self.extra_args.per_tensor_scale_a)
+                scale_b = _get_scale_per_tensor(b, custom_scale=self.extra_args.per_tensor_scale_b)
 
             # Kernels expect dtype=float8_e4m3fn
             a = a.to(torch.float8_e4m3fn)
@@ -108,16 +123,10 @@ def get_x_val(self, example_inputs) -> float:
         _, n = b.size()
         return (m, n, k)
 
-    def _get_out_dtype(self):
-        if self.extra_args.scaling_rowwise:
-            return torch.bfloat16
-        else:
-            return torch.float16
-
     @register_benchmark(baseline=True)
     def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
         )
 
     @register_benchmark()
@@ -129,7 +138,7 @@ def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
             autotune_fallback_to_aten=False,
         ):
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
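
For context, the change above replaces the hard-coded per-tensor scale of 1.0 with a scale that is either supplied via the new --per-tensor-scale-a / --per-tensor-scale-b flags or derived from the operand's max absolute value. The sketch below is a minimal, standalone illustration of that call pattern against torch._scaled_mm, assuming a recent PyTorch build and a GPU with FP8 support; the shapes, device string, and the per_tensor_scale helper name are illustrative, not part of the patch.

import torch

m, k, n = 1024, 512, 2048
device = "cuda"

a = torch.randn(m, k, device=device, dtype=torch.float16)
# torch._scaled_mm expects the second operand in column-major layout,
# hence the .T.contiguous().T round-trip, as in the benchmark.
b = torch.randn(k, n, device=device, dtype=torch.float16).T.contiguous().T


def per_tensor_scale(x: torch.Tensor, custom_scale: float | None = None) -> torch.Tensor:
    # Scalar float32 scale: either the user-supplied override or
    # fp8 e4m3 max (448.0) divided by the tensor's max absolute value.
    if custom_scale is not None:
        return torch.tensor(custom_scale, dtype=torch.float32, device=x.device)
    return (torch.finfo(torch.float8_e4m3fn).max / x.abs().max()).to(torch.float32)


scale_a = per_tensor_scale(a)  # or per_tensor_scale(a, custom_scale=1.0)
scale_b = per_tensor_scale(b)

# The kernels consume float8_e4m3fn operands plus float32 scale tensors.
a_fp8 = a.to(torch.float8_e4m3fn)
b_fp8 = b.to(torch.float8_e4m3fn)

out = torch._scaled_mm(
    a_fp8, b_fp8, scale_a, scale_b, use_fast_accum=True, out_dtype=torch.float16
)
print(out.shape)  # torch.Size([1024, 2048])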
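With the new flags left at their default of None, each scale falls back to finfo(float8_e4m3fn).max / amax, so the FP8 range tracks the data; passing explicit values (e.g. --per-tensor-scale-a 1.0 --per-tensor-scale-b 1.0) pins both scales, presumably to keep runs comparable when input distributions differ. The rowwise path is unchanged: with scaling_rowwise enabled, the scales remain the (M, 1) and (1, N) ones tensors and _get_dtype() switches the inputs and output dtype to bfloat16.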