@@ -56,14 +56,27 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if (
             hasattr(self, "external_shapes") and self.external_shapes
@@ -90,62 +103,49 @@ def args(m, n, k):
             yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
@@ -155,7 +155,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
             return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         c = fn()
         c = c[0] if isinstance(c, tuple) else c
 
@@ -168,7 +168,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
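
For reference, here is a minimal standalone sketch (not part of this diff) of the calling convention the reworked `get_input_iter` produces: fp8 operands plus either tensor-wise or rowwise scales passed straight into `torch._scaled_mm`. The device string, problem sizes, and the `scaling_rowwise` flag below are illustrative assumptions; it presumes a recent PyTorch build with float8 support on fp8-capable CUDA hardware.

```python
import torch

device = "cuda"
m, n, k = 256, 512, 128
scaling_rowwise = False  # stands in for self.extra_args.scaling_rowwise

# Build inputs in float16 first and keep b column-major, mirroring args() above.
a = torch.randn(m, k, device=device, dtype=torch.float16)
b = torch.randn(k, n, device=device, dtype=torch.float16).T.contiguous().T

if scaling_rowwise:
    # Per-row scale for a, per-column scale for b; rowwise scaling pairs with bf16 output.
    scale_a = torch.ones((m, 1), dtype=torch.float32, device=device)
    scale_b = torch.ones((1, n), dtype=torch.float32, device=device)
    out_dtype = torch.bfloat16
else:
    # Single tensor-wise scale factors.
    scale_a = torch.tensor(1.0, device=device)
    scale_b = torch.tensor(1.0, device=device)
    out_dtype = torch.float16

# Kernels expect dtype=float8_e4m3fn
a = a.to(torch.float8_e4m3fn)
b = b.to(torch.float8_e4m3fn)

c = torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype)
c = c[0] if isinstance(c, tuple) else c  # older builds return (out, amax)
print(c.shape, c.dtype)  # torch.Size([256, 512]), float16 or bfloat16
```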