7
7
8
8
torch .set_grad_enabled (False )
9
9
10
+
10
11
def get_args ():
11
12
parser = argparse .ArgumentParser (description = "hgemm benchmark" )
12
13
parser .add_argument ("--M" , type = int , default = None , help = "Matrix M size" )
13
14
parser .add_argument ("--N" , type = int , default = None , help = "Matrix N size" )
14
15
parser .add_argument ("--K" , type = int , default = None , help = "Matrix K size" )
15
16
parser .add_argument ("--MNK" , type = int , default = None , help = "Matrix M=N=K size" )
16
- parser .add_argument ("--MMNK" , type = int , default = 16384 , help = "Matrix MAX M=M=N=K size" )
17
- parser .add_argument ("--SEP" , type = int , default = 512 , help = "Matrix MAX M=M=N=K size" )
17
+ parser .add_argument ("--MMNK" , type = int , default = 12800 , help = "Matrix MAX M=M=N=K size" )
18
+ parser .add_argument ("--SEP" , '--sep' , type = int , default = 256 , help = "Matrix SEP M=M=N=K size" )
18
19
parser .add_argument ("--warmup" , "--w" , type = int , default = 2 , help = "Warmup iters" )
19
20
parser .add_argument ("--iters" , "--i" , type = int , default = 10 , help = "Benchmark iters" )
20
21
parser .add_argument ("--verbose" , "--v" , action = "store_true" , help = "Verbose" )
@@ -36,16 +37,30 @@ def get_args():
36
37
parser .add_argument ("--no-default" , action = "store_true" , help = "Disable default tests" )
37
38
parser .add_argument ("--plot-flops" , "--plot" , action = "store_true" , help = "Plot TFLOPS" )
38
39
parser .add_argument ("--plot-topk" , "--topk" , type = int , default = 8 , help = "Plot top k TFLOPS" )
39
- parser .add_argument ("--no-hint-top1 " , "--no-hint " , action = "store_true" , help = "Hint top 1 TFLOPS" )
40
+ parser .add_argument ("--no-plot-best " , "--no-best " , action = "store_true" , help = "Not Plot best TFLOPS" )
40
41
parser .add_argument ("--exclude-tags" , "--exclude" , type = str , default = None , help = "Exclude tag for plot, sperated by comma" )
41
- parser .add_argument ("--save-tag " , "--tag " , type = str , default = None , help = "Save tag for plot" )
42
+ parser .add_argument ("--save-dir " , "--dir " , type = str , default = "./" , help = "Save dir for plot" )
42
43
return parser .parse_args ()
43
44
44
45
args = get_args ()
45
46
print (args )
46
47
48
+
49
def get_device_name():
    """Return the name of the current CUDA device, tagged for laptop GPUs.

    The name is read from ``torch.cuda.get_device_name`` for the device
    reported by ``torch.cuda.current_device``. When the name contains
    "Laptop", the suffix " WSL2" is appended.
    """
    name = torch.cuda.get_device_name(torch.cuda.current_device())
    # Laptop GPUs are benchmarked under WSL2 here, so label them accordingly.
    return (name + " WSL2") if "Laptop" in name else name
55
+
56
+
57
def get_device_capability():
    """Return what ``torch.cuda.get_device_capability`` reports for the current CUDA device."""
    current = torch.cuda.current_device()
    return torch.cuda.get_device_capability(current)
59
+
60
+
47
61
# Load the CUDA kernel as a python module
48
- print ("Loading hgemm lib ..." )
62
+ print (f"Loading hgemm lib on device: { get_device_name ()} , capability: { get_device_capability ()} ..." )
63
+
49
64
lib = load (name = 'hgemm_lib' ,
50
65
sources = ['hgemm.cu' , 'hgemm_async.cu' , 'hgemm_wmma.cu' ,
51
66
'hgemm_wmma_stage.cu' , 'hgemm_cublas.cu' ,
@@ -184,36 +199,35 @@ def run_benchmark(perf_func: callable,
184
199
return out , mean_time
185
200
186
201
187
- def get_device_name ():
188
- device_name = torch .cuda .get_device_name (torch .cuda .current_device ())
189
- # we will run GPU on WSL2, so add WSL2 tag.
190
- if "Laptop" in device_name :
191
- device_name += " WSL2"
192
- return device_name
193
-
194
-
195
- def get_device_capability ():
196
- return torch .cuda .get_device_capability (torch .cuda .current_device ())
197
-
198
-
199
202
def get_topk_tflops():
    """Print the TFLOPS summary table and return the top-k tags.

    Entries of ``TOATL_TFLOPS`` are ranked by value, highest first. The table
    is printed in ascending order (best entry last), followed by the cuBLAS
    total from ``CUBLAS_TOTAL_TFLOPS``.

    Returns:
        A list with the tags of the ``args.plot_topk`` highest-ranked
        entries, best first.
    """
    ranked = sorted(TOATL_TFLOPS.items(), key=lambda item: item[1], reverse=True)
    rule = "-" * 130

    print(rule)
    print(" " * 32 + f"THE TOTAL TFLOPS OF {len(ranked)} HGEMM ALGO ON {get_device_name()} DEVICE")
    print(rule)
    # Walk the ranking backwards so the best entry is printed last.
    for tag, tflops in reversed(ranked):
        print(f"{tag:>45}: {tflops:>20.2f} TFLOPS")
    print(f"{'(cublas)':>45}: {CUBLAS_TOTAL_TFLOPS:>20.2f} TFLOPS")
    print(rule)

    # Tags come from dict items, so they are already unique.
    return [tag for tag, _ in ranked[:args.plot_topk]]
213
+
214
+
215
def get_best_tflops():
    """Return the element-wise best TFLOPS across all non-cuBLAS entries.

    Every ``STATIS_INFO`` entry whose tag contains neither "cublas" nor "MNK"
    is collected, stacked into a float tensor and reduced with a max over the
    first dimension.

    Assumes each collected value is an equal-length list of TFLOPS numbers
    (one per benchmarked M=N=K size) -- TODO confirm against where
    ``STATIS_INFO`` is filled.

    Returns:
        A list of floats holding the per-position maximum, or an empty list
        when no entry qualifies.
    """
    per_kernel = [
        tflops
        for tag, tflops in STATIS_INFO.items()
        if "cublas" not in tag and "MNK" not in tag
    ]
    # torch.max over dim 0 raises on an empty tensor; with nothing to reduce
    # there is no "best" curve, so return an empty one instead of crashing.
    if not per_kernel:
        return []
    # [N, NUM_MNK], reduce max on the N dim.
    stacked = torch.tensor(per_kernel, dtype=torch.float)
    return torch.max(stacked, dim=0, keepdim=False)[0].tolist()
210
224
211
225
212
226
def plot_tflops ():
213
227
import matplotlib .pyplot as plt
214
228
import numpy as np
215
- _ , ax = plt .subplots (figsize = (16 , 9 ))
216
- plt .subplots_adjust (left = 0.03 , right = 0.99 , top = 0.95 , bottom = 0.05 )
229
+ ax : plt . Axes = plt .subplots (figsize = (16 , 9 ))[ 1 ] # fig, axs
230
+ plt .subplots_adjust (left = 0.04 , right = 0.99 , top = 0.95 , bottom = 0.05 )
217
231
ax .set_title (f"My HGEMM vs cuBLAS, { get_device_name ()} , Warmup={ args .warmup } , Iters={ args .iters } " )
218
232
ax .set_xlabel ("M=N=K" )
219
233
ax .set_ylabel ("TFLOPS" )
@@ -224,36 +238,37 @@ def plot_tflops():
224
238
exclude_tags .append ("MNK" )
225
239
exclude_tags = set (exclude_tags )
226
240
227
- def should_exclude (tag : str ) -> bool :
241
+ topk_tflops = get_topk_tflops ()
242
+ STATIS_INFO ["(best)" ] = get_best_tflops ()
243
+ draw_tags = topk_tflops
244
+ draw_tags .append ("(cublas)" )
245
+ draw_tags .append ("(best)" )
246
+
247
+ def skip_it (tag : str ) -> bool :
228
248
for etag in exclude_tags :
229
249
if etag in tag :
230
250
return True
251
+ if tag not in draw_tags :
252
+ return True
231
253
return False
232
254
233
- topk_tflops = get_topk_tflops ()
234
- is_top_1 = True
255
+ # draw by topk order
235
256
for tag , tflops in STATIS_INFO .items ():
236
- if (should_exclude (tag )) or (tag not in topk_tflops
237
- and "cublas" not in tag ):
257
+ if skip_it (tag ):
238
258
continue
239
259
if "cublas" in tag :
240
260
ax .plot (tflops , label = tag , linewidth = 3 )
241
261
else :
242
- if is_top_1 and not args .no_hint_top1 :
262
+ if "best" in tag and not args .no_plot_best :
243
263
ax .plot (tflops , label = tag , linewidth = 4 )
244
- is_top_1 = False
245
264
else :
246
265
ax .plot (tflops , label = tag , linestyle = '--' )
247
266
248
267
ax .legend ()
249
- if args .save_tag :
250
- plt .savefig (f"{ args .save_tag } " , dpi = 300 )
251
- print (f"plot hgemm TFLOPS done, saved as { args .save_tag } " )
252
- else :
253
- device_name = get_device_name ().replace (" " , "_" )
254
- save_tag = f"{ device_name } .png"
255
- plt .savefig (save_tag , dpi = 300 )
256
- print (f"plot hgemm TFLOPS done, saved as { save_tag } " )
268
+ device_name = get_device_name ().replace (" " , "_" )
269
+ save_tag = f"{ args .save_dir } /{ device_name } .png"
270
+ plt .savefig (save_tag , dpi = 300 )
271
+ print (f"plot hgemm TFLOPS done, saved as { save_tag } " )
257
272
258
273
259
274
def get_mnk (sep : int = args .SEP ):
@@ -386,7 +401,4 @@ def get_mnk(sep: int = args.SEP):
386
401
print ("-" * 130 )
387
402
388
403
if args .plot_flops :
389
- try :
390
- plot_tflops ()
391
- except Exception as e :
392
- print (f"plot hgemm TFLOPS failed, { e } " )
404
+ plot_tflops ()
0 commit comments