@@ -56,14 +56,27 @@ def __init__(
 
     def get_input_iter(self):
         def args(m, n, k):
-            a = torch.randn(m, k, device=self.device).to(torch.float8_e4m3fn)
+            a = torch.randn(m, k, device=self.device).to(torch.float16)
             b = (
                 torch.randn(k, n, device=self.device)
-                .to(torch.float8_e4m3fn)
+                .to(torch.float16)
                 .T.contiguous()
                 .T
             )
-            return (a, b)
+
+            if self.extra_args.scaling_rowwise:
+                M, N = a.shape[0], b.shape[1]
+                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
+                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
+            else:
+                scale_a = torch.tensor(1.0, device=a.device)
+                scale_b = torch.tensor(1.0, device=a.device)
+
+            # Kernels expect dtype=float8_e4m3fn
+            a = a.to(torch.float8_e4m3fn)
+            b = b.to(torch.float8_e4m3fn)
+
+            return (a, b, scale_a, scale_b)
 
         if hasattr(self, 'external_shapes') and self.external_shapes:  # Check for external shapes loaded from input-loader
             for shape in self.external_shapes:
@@ -86,62 +99,49 @@ def args(m, n, k):
                 yield args(m, n, k)
 
     def get_x_val(self, example_inputs) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         return (m, n, k)
 
-    @register_benchmark(baseline=True)
-    def torch_fp8_gemm(self, a, b):
+    def _get_out_dtype(self):
         if self.extra_args.scaling_rowwise:
-            M, N = a.shape[0], b.shape[1]
-            scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-            scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-            out_dtype = torch.bfloat16
+            return torch.bfloat16
         else:
-            scale_a = torch.tensor(1.0, device=a.device)
-            scale_b = torch.tensor(1.0, device=a.device)
-            out_dtype = torch.float16
+            return torch.float16
 
+    @register_benchmark(baseline=True)
+    def torch_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: torch._scaled_mm(
-            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+            a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
         )
 
     @register_benchmark()
-    def pt2_fp8_gemm(self, a, b) -> Callable:
+    def pt2_fp8_gemm(self, a, b, scale_a, scale_b) -> Callable:
         torch._dynamo.reset()
         with inductor_config.patch(
             max_autotune=True,
             max_autotune_gemm_backends="TRITON",
             autotune_fallback_to_aten=False,
         ):
-            if self.extra_args.scaling_rowwise:
-                M, N = a.shape[0], b.shape[1]
-                scale_a = torch.ones((M, 1), dtype=torch.float32, device=a.device)
-                scale_b = torch.ones((1, N), dtype=torch.float32, device=b.device)
-                out_dtype = torch.bfloat16
-            else:
-                scale_a = torch.tensor(1.0, device=a.device)
-                scale_b = torch.tensor(1.0, device=b.device)
-                out_dtype = torch.float16
             f = lambda a, b: torch._scaled_mm(
-                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype
+                a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=self._get_out_dtype()
             )
             compiled = torch.compile(f, dynamic=False)
             compiled(a, b)
 
         return lambda: compiled(a, b)
 
     @register_benchmark()
-    def triton_fp8_gemm(self, a, b):
+    def triton_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: tutorial_matmul(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_persistent_fp8_gemm(self, a, b):
+    def triton_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         return lambda: matmul_persistent(a, b)
 
     @register_benchmark(enabled=HAS_TMA)
-    def triton_tma_persistent_fp8_gemm(self, a, b):
+    def triton_tma_persistent_fp8_gemm(self, a, b, scale_a, scale_b):
         b = b.T.contiguous()
         c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b)
         return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c)
@@ -151,7 +151,7 @@ def gbps(self, fn, example_inputs: Any, metrics: BenchmarkOperatorMetrics) -> fl
         def nbytes(t):
            return t.numel() * t.element_size()
 
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
        c = fn()
        c = c[0] if isinstance(c, tuple) else c
@@ -164,7 +164,7 @@ def nbytes(t):
     def flops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
     ) -> float:
-        a, b = example_inputs
+        a, b, _, _ = example_inputs
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
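
Note: the sketch below is not part of the patch. It is a minimal, standalone illustration of how inputs shaped like the new (a, b, scale_a, scale_b) tuple produced by args() feed the baseline torch._scaled_mm call; the matrix sizes, the "cuda" device string, and the scaling_rowwise flag are assumptions standing in for self.extra_args.scaling_rowwise and the benchmark's shape list.

# Illustrative sketch only; mirrors the patched input construction and baseline call.
import torch

m, n, k = 1024, 1024, 1024
scaling_rowwise = False  # stands in for self.extra_args.scaling_rowwise

a = torch.randn(m, k, device="cuda").to(torch.float16)
# .T.contiguous().T leaves b column-major, the layout torch._scaled_mm expects for the second operand.
b = torch.randn(k, n, device="cuda").to(torch.float16).T.contiguous().T

if scaling_rowwise:
    scale_a = torch.ones((m, 1), dtype=torch.float32, device=a.device)
    scale_b = torch.ones((1, n), dtype=torch.float32, device=b.device)
    out_dtype = torch.bfloat16  # rowwise-scaled path returns bfloat16 in the benchmark
else:
    scale_a = torch.tensor(1.0, device=a.device)
    scale_b = torch.tensor(1.0, device=a.device)
    out_dtype = torch.float16

# Kernels expect float8_e4m3fn operands; the scale tensors stay in float32.
a = a.to(torch.float8_e4m3fn)
b = b.to(torch.float8_e4m3fn)

c = torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out_dtype=out_dtype)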