Skip to content

Commit c621ce5

Browse files
[moe training] add bench script for fp8 rowwise kernels and update autotune configs
stack-info: PR: #2697, branch: danielvegamyhre/stack/31
1 parent ef4e25c commit c621ce5

File tree

5 files changed

+166
-13
lines changed

5 files changed

+166
-13
lines changed

test/prototype/moe_training/test_kernels.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,6 @@ def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
109109
target_dtype=torch.float8_e4m3fn,
110110
round_scales_to_power_of_2=round_scales_to_power_of_2,
111111
)
112-
# Pytorch impl keeps the empty scaled dimension, so we need to squeeze it out.
113-
ref_scales = ref_scales.squeeze(1)
114112

115113
triton_fp8, triton_scales = triton_fp8_rowwise_3d_transpose_rhs(
116114
x,
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
from triton.testing import do_bench
16+
17+
from torchao.prototype.moe_training.kernels.float8_rowwise import (
18+
triton_fp8_rowwise_3d_transpose_rhs,
19+
)
20+
from torchao.prototype.moe_training.utils import (
21+
torch_to_3d_rowwise_float8_transpose_rhs,
22+
)
23+
24+
device = torch.device("cuda")
25+
26+
# Needed since changing args to function causes recompiles
27+
torch._dynamo.config.cache_size_limit = 1000
28+
29+
30+
@dataclass(frozen=True)
31+
class ExperimentConfig:
32+
high_precision_dtype: torch.dtype
33+
input_shape: tuple[int]
34+
35+
36+
@dataclass(frozen=True)
37+
class ExperimentResult:
38+
torch_time_us: float
39+
triton_time_us: float
40+
41+
42+
@dataclass(frozen=True)
43+
class Experiment:
44+
config: ExperimentConfig
45+
result: ExperimentResult
46+
47+
48+
def get_configs() -> List[ExperimentConfig]:
49+
# Llama4 and DeepSeekV3 shapes
50+
input_shapes = [(8, 4096, 1024), (16, 5120 * 4, 5120)]
51+
high_precision_dtypes = [torch.bfloat16]
52+
configs = []
53+
for input_shape, high_precision_dtype in itertools.product(
54+
input_shapes, high_precision_dtypes
55+
):
56+
configs.append(
57+
ExperimentConfig(
58+
input_shape=input_shape,
59+
high_precision_dtype=high_precision_dtype,
60+
)
61+
)
62+
return configs
63+
64+
65+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
66+
# Expert weights will be passed in transposed and column major in practice
67+
input_tensor = torch.randn(
68+
*config.input_shape,
69+
dtype=config.high_precision_dtype,
70+
device=device,
71+
).transpose(-2, -1)
72+
73+
def warmup(func, *args, **kwargs):
74+
for _ in range(10):
75+
func(*args, **kwargs)
76+
77+
def run_torch(input_tensor: torch.Tensor):
78+
out = torch_to_3d_rowwise_float8_transpose_rhs(
79+
input_tensor,
80+
target_dtype=torch.float8_e4m3fn,
81+
round_scales_to_power_of_2=True,
82+
)
83+
torch.cuda.synchronize()
84+
return out
85+
86+
def run_triton(input_tensor: torch.Tensor):
87+
_ = triton_fp8_rowwise_3d_transpose_rhs(
88+
input_tensor,
89+
output_dtype=torch.float8_e4m3fn,
90+
round_scales_to_power_of_2=True,
91+
)
92+
torch.cuda.synchronize()
93+
94+
# bench torch
95+
compiled_run_torch = torch.compile(run_torch)
96+
warmup(run_torch, input_tensor)
97+
torch_time_us = benchmark_cuda_function_in_microseconds(
98+
compiled_run_torch,
99+
input_tensor,
100+
)
101+
102+
# bench triton
103+
warmup(run_triton, input_tensor)
104+
triton_time_us = benchmark_cuda_function_in_microseconds(
105+
run_triton,
106+
input_tensor,
107+
)
108+
109+
return ExperimentResult(
110+
torch_time_us=torch_time_us,
111+
triton_time_us=triton_time_us,
112+
)
113+
114+
115+
def print_results(experiments: List[Experiment]):
116+
headers = [
117+
"input_shape",
118+
"torch_time_us",
119+
"triton_time_us",
120+
]
121+
rows = []
122+
for experiment in experiments:
123+
input_shape = f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1], experiment.config.input_shape[2]})"
124+
rows.append(
125+
[
126+
input_shape,
127+
experiment.result.torch_time_us,
128+
experiment.result.triton_time_us,
129+
]
130+
)
131+
print(tabulate(rows, headers=headers))
132+
133+
134+
def benchmark_cuda_function_in_microseconds(f, *args):
135+
return do_bench(lambda: f(*args), return_mode="median") * 1e3
136+
137+
138+
def main():
139+
torch.random.manual_seed(123)
140+
configs = get_configs()
141+
results = []
142+
for config in tqdm(configs):
143+
result = run_experiment(config)
144+
results.append(Experiment(config=config, result=result))
145+
146+
# Use Tabulate to print results
147+
print_results(results)
148+
149+
150+
if __name__ == "__main__":
151+
main()

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,18 @@
2626
torch.float64: tl.float64,
2727
}
2828

29-
block_sizes = [16]
30-
num_warps = [4]
31-
num_stages = [2]
29+
block_sizes_n = [32, 128, 512] # large dim (output_features)
30+
block_sizes_k = [32, 128, 512] # small dim (input_features)
31+
num_warps = [8]
32+
num_stages = [2, 3]
3233
kernel_configs_2D = [
3334
triton.Config(
34-
{"BLOCK_SIZE_N": block_size, "BLOCK_SIZE_K": block_size * 2},
35+
{"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
3536
num_warps=warps,
3637
num_stages=stages,
3738
)
38-
for block_size in block_sizes
39+
for block_size_n in block_sizes_n
40+
for block_size_k in block_sizes_k
3941
for warps in num_warps
4042
for stages in num_stages
4143
]
@@ -62,8 +64,9 @@ def triton_fp8_rowwise_3d_transpose_rhs(
6264

6365
# allocate on-device buffers for output and scales
6466
# output shape = input.transpose(-2, -1).shape = (E, N, K) in column major layout
65-
output_buffer = torch.empty((e, k, n), dtype=output_dtype, device=hp_tensor.device)
66-
output_buffer = output_buffer.transpose(-2, -1)
67+
output_buffer = torch.empty_like(
68+
hp_tensor, dtype=output_dtype, device=hp_tensor.device
69+
)
6770
scales_buffer = torch.full(
6871
(e, k), float("inf"), dtype=torch.float32, device=hp_tensor.device
6972
)

torchao/prototype/moe_training/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def torch_to_float8_per_group_rowwise(
146146

147147

148148
def torch_to_3d_rowwise_float8_transpose_rhs(
149-
input_hp: torch.Tensor, # (E, K, N)
149+
input_hp_t: torch.Tensor, # (E, K, N)
150150
target_dtype: torch.dtype = torch.float8_e4m3fn,
151151
round_scales_to_power_of_2: bool = False,
152152
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -162,17 +162,18 @@ def torch_to_3d_rowwise_float8_transpose_rhs(
162162
Output shape: (E, N, K)
163163
Scales shape: (E, 1, K
164164
"""
165-
input_hp_t = input_hp.transpose(-2, -1) # (E, N, K)
165+
assert _is_column_major(input_hp_t), "input tensor must be column-major"
166+
input_hp = input_hp_t.transpose(-2, -1) # (E, N, K)
166167
scales = tensor_to_scale(
167-
input_hp_t,
168+
input_hp,
168169
target_dtype,
169170
scaling_granularity=ScalingGranularity.AXISWISE,
170171
axiswise_dim=-2,
171172
round_scales_to_power_of_2=round_scales_to_power_of_2,
172173
) # (E, 1, K)
173174

174175
# Apply scales to tensor and convert to float8.
175-
tensor_scaled = input_hp_t.to(torch.float32) * scales
176+
tensor_scaled = input_hp.to(torch.float32) * scales
176177
float8_tensor = to_fp8_saturated(tensor_scaled, target_dtype)
177178

178179
# To column major

0 commit comments

Comments
 (0)