
Commit 353c947

[HGEMM] Update HGEMM benchmark scripts (#123)
* Update hgemm.py
* Update README.md
* Update hgemm.py
* Update README.md
* Update hgemm.py
* Update hgemm.py
* Update hgemm.py
* Update hgemm.py
1 parent 0beec5a commit 353c947

File tree

2 files changed: +72, -60 lines

hgemm/README.md

Lines changed: 18 additions & 18 deletions

````diff
@@ -66,6 +66,24 @@ python3 hgemm.py --mma-all --plot --topk 8
 
 ## Current Performance
 
+### NVIDIA GeForce RTX 3080 Laptop
+
+Benchmarked with mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) plus thread block swizzle, under Windows WSL2 on an NVIDIA GeForce RTX 3080 Laptop GPU; most cases match or even beat cuBLAS.
+
+![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)
+
+```bash
+python3 hgemm.py --wmma-all
+----------------------------------------------------------------------------------------------------------------------------------
+M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27
+----------------------------------------------------------------------------------------------------------------------------------
+(wmma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%)
+(wmma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75
+(wmma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%)
+(wmma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95
+(cublas): ['68.375 ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
+----------------------------------------------------------------------------------------------------------------------------------
+```
 ### NVIDIA L20
 
 With the current best implementation on the L20 (theoretical Tensor Core FP16 throughput: 119.5 TFLOPS), the WMMA API reaches roughly 95%~98% of cuBLAS performance (105-113 TFLOPS vs 105-115 TFLOPS), and the MMA API reaches 115 TFLOPS, beating cuBLAS in some cases. A known issue is that bank conflicts are not fully eliminated; the current padding-based mitigation wastes shared memory and also hurts SM occupancy. Manual smem swizzle/permute is not implemented yet (limited by the flexibility of the WMMA API and the row-major layout); a follow-up will attempt smem swizzle/permute via MMA PTX.
````
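A quick sanity check on the log above: an MxNxK half-precision GEMM performs 2\*M\*N\*K floating-point operations, so the reported TFLOPS follow directly from the measured times. A minimal sketch in plain Python (the helper `hgemm_tflops` is mine, not part of hgemm.py, whose own bookkeeping may round differently):

```python
# One GEMM costs 2 * M * N * K FLOPs (a multiply and an add per term).
def hgemm_tflops(M: int, N: int, K: int, time_ms: float) -> float:
    return (2.0 * M * N * K) / (time_ms * 1e-3) / 1e12

# Values taken from the RTX 3080 Laptop log above:
print(hgemm_tflops(16384, 16384, 8192, 85.658))    # ~51.34 (stage3+dsmem+swizzle)
print(hgemm_tflops(16384, 16384, 8192, 104.2092))  # ~42.20 (cublas)
```

The "+13.15%" tag on the swizzle entry is the speedup over the first row's baseline: 96.91984 ms / 85.658 ms ≈ 1.1315.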

````diff
@@ -147,24 +165,6 @@ python3 hgemm.py --mma-all --wmma-all --cuda-all
 ----------------------------------------------------------------------------------------------------------------------------------
 ```
 
-### NVIDIA GeForce RTX 3080 Laptop
-
-Tested with mma4x4_warp4x4 (16 WMMA m16n16k16 ops, warp tile 64x64) plus thread block swizzle on an NVIDIA GeForce RTX 3080 Laptop; most cases match or even beat cuBLAS. However, since I ran these on the Laptop under WSL, the numbers are unstable; treat this section as a rough reference and do not take it too seriously.
-
-![](./NVIDIA_GeForce_RTX_3080_Laptop_GPU_WSL2.png)
-
-```bash
-python3 hgemm.py --wmma-all
-----------------------------------------------------------------------------------------------------------------------------------
-M=16384, N=16384, K=8192, Warmup=5, Iters=20, 27/27
-----------------------------------------------------------------------------------------------------------------------------------
-(wmma4x4+warp4x4+stage3+dsmem): ['68.375 ', '-2.234375 '], time:96.91984ms, swizzle: NOOP, TFLOPS: 45.38 (+0.00%)
-(wmma4x4+warp4x4+stage2+dsmem): ['68.375 ', '-2.234375 '], time:102.8722ms, swizzle: NOOP, TFLOPS: 42.75
-(wmma4x4+warp4x4+stage3+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:85.65800ms, swizzle: 4096, TFLOPS: 51.34 (+13.15%)
-(wmma4x4+warp4x4+stage2+dsmem+swizzle): ['68.375 ', '-2.234375 '], time:95.70884ms, swizzle: 4096, TFLOPS: 45.95
-(cublas): ['68.375 ', '-2.234375 '], time:104.2092ms, swizzle: NOOP, TFLOPS: 42.20
-----------------------------------------------------------------------------------------------------------------------------------
-```
 
 ## Performance Optimization Notes
 
````

hgemm/hgemm.py

Lines changed: 54 additions & 42 deletions

```diff
@@ -7,14 +7,15 @@
 
 torch.set_grad_enabled(False)
 
+
 def get_args():
     parser = argparse.ArgumentParser(description="hgemm benchmark")
     parser.add_argument("--M", type=int, default=None, help="Matrix M size")
     parser.add_argument("--N", type=int, default=None, help="Matrix N size")
     parser.add_argument("--K", type=int, default=None, help="Matrix K size")
     parser.add_argument("--MNK", type=int, default=None, help="Matrix M=N=K size")
-    parser.add_argument("--MMNK", type=int, default=16384, help="Matrix MAX M=M=N=K size")
-    parser.add_argument("--SEP", type=int, default=512, help="Matrix MAX M=M=N=K size")
+    parser.add_argument("--MMNK", type=int, default=12800, help="Matrix MAX M=M=N=K size")
+    parser.add_argument("--SEP", '--sep', type=int, default=256, help="Matrix SEP M=M=N=K size")
     parser.add_argument("--warmup", "--w", type=int, default=2, help="Warmup iters")
     parser.add_argument("--iters", "--i", type=int, default=10, help="Benchmark iters")
     parser.add_argument("--verbose", "--v", action="store_true", help="Verbose")
@@ -36,16 +37,30 @@ def get_args():
     parser.add_argument("--no-default", action="store_true", help="Disable default tests")
     parser.add_argument("--plot-flops", "--plot", action="store_true", help="Plot TFLOPS")
     parser.add_argument("--plot-topk", "--topk", type=int, default=8, help="Plot top k TFLOPS")
-    parser.add_argument("--no-hint-top1", "--no-hint", action="store_true", help="Hint top 1 TFLOPS")
+    parser.add_argument("--no-plot-best", "--no-best", action="store_true", help="Do not plot best TFLOPS")
     parser.add_argument("--exclude-tags", "--exclude", type=str, default=None, help="Exclude tags for plot, separated by comma")
-    parser.add_argument("--save-tag", "--tag", type=str, default=None, help="Save tag for plot")
+    parser.add_argument("--save-dir", "--dir", type=str, default="./", help="Save dir for plot")
     return parser.parse_args()
 
 args = get_args()
 print(args)
 
+
+def get_device_name():
+    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
+    # we run the GPU under WSL2, so add a WSL2 tag
+    if "Laptop" in device_name:
+        device_name += " WSL2"
+    return device_name
+
+
+def get_device_capability():
+    return torch.cuda.get_device_capability(torch.cuda.current_device())
+
+
 # Load the CUDA kernel as a python module
-print("Loading hgemm lib ...")
+print(f"Loading hgemm lib on device: {get_device_name()}, capability: {get_device_capability()} ...")
+
 lib = load(name='hgemm_lib',
            sources=['hgemm.cu', 'hgemm_async.cu', 'hgemm_wmma.cu',
                     'hgemm_wmma_stage.cu', 'hgemm_cublas.cu',
@@ -184,36 +199,35 @@ def run_benchmark(perf_func: callable,
     return out, mean_time
 
 
-def get_device_name():
-    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
-    # we will run GPU on WSL2, so add WSL2 tag
-    if "Laptop" in device_name:
-        device_name += " WSL2"
-    return device_name
-
-
-def get_device_capability():
-    return torch.cuda.get_device_capability(torch.cuda.current_device())
-
-
 def get_topk_tflops():
     topk_tflops = sorted(TOATL_TFLOPS.items(), key=lambda x: x[1],
                          reverse=True)
     print("-" * 130)
-    print(" " * 42 + f"HGEMM TOTAL TFLOPS, {get_device_name()}")
+    print(" " * 32 + f"THE TOTAL TFLOPS OF {len(topk_tflops)} HGEMM ALGO ON {get_device_name()} DEVICE")
     print("-" * 130)
     for tag, tflops in list(topk_tflops)[::-1]:
-        print(f"{tag:>42}: {tflops:<10.2f} TFLOPS")
-    print(f"{'(cublas)':>42}: {CUBLAS_TOTAL_TFLOPS:<10.2f} TFLOPS")
+        print(f"{tag:>45}: {tflops:>20.2f} TFLOPS")
+    print(f"{'(cublas)':>45}: {CUBLAS_TOTAL_TFLOPS:>20.2f} TFLOPS")
     print("-" * 130)
-    return dict(topk_tflops[:args.plot_topk]).keys()
+    return list(dict(topk_tflops[:args.plot_topk]).keys())
+
+
+def get_best_tflops():
+    all_tflops = []
+    for tag, tflops in STATIS_INFO.items():
+        if "cublas" not in tag and "MNK" not in tag:
+            all_tflops.append(tflops)
+    # [N, NUM_MNK], reduce max on N dim
+    all_tflops = torch.tensor(all_tflops, dtype=torch.float)
+    best_tflops = torch.max(all_tflops, dim=0, keepdim=False)[0].tolist()
+    return best_tflops
 
 
 def plot_tflops():
     import matplotlib.pyplot as plt
     import numpy as np
-    _, ax = plt.subplots(figsize=(16, 9))
-    plt.subplots_adjust(left=0.03, right=0.99, top=0.95, bottom=0.05)
+    ax: plt.Axes = plt.subplots(figsize=(16, 9))[1]  # fig, axs
+    plt.subplots_adjust(left=0.04, right=0.99, top=0.95, bottom=0.05)
     ax.set_title(f"My HGEMM vs cuBLAS, {get_device_name()}, Warmup={args.warmup}, Iters={args.iters}")
     ax.set_xlabel("M=N=K")
     ax.set_ylabel("TFLOPS")
```
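The new get_best_tflops() builds the "(best)" envelope curve that plot_tflops() draws with a thick line: STATIS_INFO maps each kernel tag to its TFLOPS list across the benchmarked sizes (plus an "MNK" entry holding the x-axis values, hence the double filter), and an element-wise max over the non-cuBLAS rows picks the best kernel at every size. A self-contained sketch of that reduction with toy numbers (not measured data):

```python
import torch

# Toy stand-in for STATIS_INFO: tag -> TFLOPS per M=N=K point.
statis_info = {
    "(wmma4x4+warp4x4+stage3)": [30.0, 45.0, 51.0],
    "(wmma4x4+warp4x4+stage2)": [32.0, 42.0, 46.0],
    "(cublas)":                 [33.0, 44.0, 42.0],  # excluded from "(best)"
}

# Stack non-cuBLAS rows into a [N_algos, NUM_MNK] tensor, then reduce
# max over dim=0: the best achieved TFLOPS at each problem size.
rows = [v for k, v in statis_info.items() if "cublas" not in k]
best = torch.tensor(rows, dtype=torch.float).max(dim=0).values.tolist()
print(best)  # [32.0, 45.0, 51.0] -> the "(best)" envelope
```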

```diff
@@ -224,36 +238,37 @@ def plot_tflops():
     exclude_tags.append("MNK")
     exclude_tags = set(exclude_tags)
 
-    def should_exclude(tag: str) -> bool:
+    topk_tflops = get_topk_tflops()
+    STATIS_INFO["(best)"] = get_best_tflops()
+    draw_tags = topk_tflops
+    draw_tags.append("(cublas)")
+    draw_tags.append("(best)")
+
+    def skip_it(tag: str) -> bool:
         for etag in exclude_tags:
             if etag in tag:
                 return True
+        if tag not in draw_tags:
+            return True
         return False
 
-    topk_tflops = get_topk_tflops()
-    is_top_1 = True
+    # draw by topk order
     for tag, tflops in STATIS_INFO.items():
-        if (should_exclude(tag)) or (tag not in topk_tflops
-                                     and "cublas" not in tag):
+        if skip_it(tag):
             continue
         if "cublas" in tag:
             ax.plot(tflops, label=tag, linewidth=3)
         else:
-            if is_top_1 and not args.no_hint_top1:
+            if "best" in tag and not args.no_plot_best:
                 ax.plot(tflops, label=tag, linewidth=4)
-                is_top_1 = False
             else:
                 ax.plot(tflops, label=tag, linestyle='--')
 
     ax.legend()
-    if args.save_tag:
-        plt.savefig(f"{args.save_tag}", dpi=300)
-        print(f"plot hgemm TFLOPS done, saved as {args.save_tag}")
-    else:
-        device_name = get_device_name().replace(" ", "_")
-        save_tag = f"{device_name}.png"
-        plt.savefig(save_tag, dpi=300)
-        print(f"plot hgemm TFLOPS done, saved as {save_tag}")
+    device_name = get_device_name().replace(" ", "_")
+    save_tag = f"{args.save_dir}/{device_name}.png"
+    plt.savefig(save_tag, dpi=300)
+    print(f"plot hgemm TFLOPS done, saved as {save_tag}")
 
 
 def get_mnk(sep: int = args.SEP):
@@ -386,7 +401,4 @@ def get_mnk(sep: int = args.SEP):
     print("-" * 130)
 
 if args.plot_flops:
-    try:
-        plot_tflops()
-    except Exception as e:
-        print(f"plot hgemm TFLOPS failed, {e}")
+    plot_tflops()
```
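With --save-tag removed, the figure is always written to f"{args.save_dir}/{device_name}.png", e.g. `python3 hgemm.py --wmma-all --plot --dir ./plots`. A small sketch of an equivalent path construction (the helper `plot_save_path` is hypothetical, not part of hgemm.py); `os.path.join` avoids the doubled separator that the f-string produces when the directory already ends with "/":

```python
import os

def plot_save_path(save_dir: str, device_name: str) -> str:
    # Equivalent to f"{save_dir}/{device_name}.png", except that
    # os.path.join normalizes a trailing "/" (no ".//NVIDIA_L20.png").
    filename = device_name.replace(" ", "_") + ".png"
    return os.path.join(save_dir, filename)

print(plot_save_path("./", "NVIDIA L20"))       # ./NVIDIA_L20.png
print(plot_save_path("./plots", "NVIDIA L20"))  # ./plots/NVIDIA_L20.png
```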
