Skip to content

Commit 03c179b

Browse files
[moe training] use smaller block sizes for per group scaling kernels to improve perf
1 parent: 1f0d2bb · commit: 03c179b

File tree

2 files changed

+12
-11
lines changed

2 files changed

+12
-11
lines changed

torchao/prototype/moe_training/benchmarks/benchmark_kernels.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,13 @@
66
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
77

88
import itertools
9-
import time
109
from dataclasses import dataclass
1110
from typing import List
1211

1312
import torch
1413
from tabulate import tabulate
1514
from tqdm import tqdm
15+
from triton.testing import do_bench
1616

1717
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
1818
triton_fp8_col_major_jagged_colwise_scales,
@@ -129,18 +129,15 @@ def run_triton(
129129

130130
# bench torch
131131
compiled_run_torch = torch.compile(run_torch)
132-
warmup(compiled_run_torch, input_row_major, input_col_major, offs)
133-
start_time_ns = time.perf_counter_ns()
134-
compiled_run_torch(input_row_major, input_col_major, offs)
135-
torch_time_ns = time.perf_counter_ns() - start_time_ns
136-
torch_time_us = torch_time_ns / 1e3
132+
torch_time_us = benchmark_cuda_function_in_microseconds(
133+
compiled_run_torch, input_row_major, input_col_major, offs
134+
)
137135

138136
# bench triton
139137
warmup(run_triton, input_row_major, input_col_major, offs)
140-
start_time_ns = time.perf_counter_ns()
141-
run_triton(input_row_major, input_col_major, offs)
142-
triton_time_ns = time.perf_counter_ns() - start_time_ns
143-
triton_time_us = triton_time_ns / 1e3
138+
triton_time_us = benchmark_cuda_function_in_microseconds(
139+
run_triton, input_row_major, input_col_major, offs
140+
)
144141

145142
return ExperimentResult(
146143
torch_time_us=torch_time_us,
@@ -173,6 +170,10 @@ def print_results(experiments: List[Experiment]):
173170
print(tabulate(rows, headers=headers))
174171

175172

173+
def benchmark_cuda_function_in_microseconds(f, *args):
174+
return do_bench(lambda: f(*args), return_mode="median") * 1e3
175+
176+
176177
def main():
177178
torch.random.manual_seed(123)
178179
configs = get_configs()

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
torch.float64: tl.float64,
3434
}
3535

36-
block_sizes = [128, 256]
36+
block_sizes = [16, 32]
3737
kernel_configs_2D = [
3838
triton.Config(
3939
{"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols}

0 commit comments

Comments
 (0)