@@ -56,14 +56,27 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if hasattr(self, 'external_shapes') and self.external_shapes:  # Check for external shapes loaded from input-loader
             for shape in self.external_shapes:
@@ -86,62 +99,49 @@ def args(m, n, k):
             yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
@@ -151,7 +151,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
@@ -164,7 +164,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
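
For reference, here is a minimal standalone sketch (not part of the PR) of how the new four-element input tuple is built and consumed. It assumes a CUDA device with FP8 support (e.g. H100) and a PyTorch build where torch._scaled_mm takes scale_a/scale_b positionally, as the diff does; the helper name make_inputs is illustrative only.

# Hypothetical sketch mirroring the benchmark's new (a, b, scale_a, scale_b) inputs.
import torch

def make_inputs(m, n, k, scaling_rowwise, device="cuda"):
    # Sample in a higher-precision dtype first, as the diff now does.
    a = torch.randn(m, k, device=device).to(torch.float16)
    # Make b column-major (the .T.contiguous().T trick from the benchmark),
    # the layout torch._scaled_mm expects for the second operand.
    b = torch.randn(k, n, device=device).to(torch.float16).T.contiguous().T
    if scaling_rowwise:
        # Row-wise scaling: one scale per row of a and per column of b.
        scale_a = torch.ones((m, 1), dtype=torch.float32, device=device)
        scale_b = torch.ones((1, n), dtype=torch.float32, device=device)
    else:
        # Per-tensor scaling: scalar scales.
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
    # Kernels expect dtype=float8_e4m3fn.
    return a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn), scale_a, scale_b

a, b, scale_a, scale_b = make_inputs(1024, 1024, 1024, scaling_rowwise=True)
out = torch._scaled_mm(
    a, b, scale_a, scale_b, use_fast_accum=True,
    out_dtype=torch.bfloat16,  # bfloat16 for row-wise, float16 for per-tensor
)
print(out.shape, out.dtype)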