Commit 351a2a9

[moe training] use smaller block sizes for per group scaling kernels to improve perf
stack-info: PR: #2668, branch: danielvegamyhre/stack/27
1 parent 18edd01 commit 351a2a9

3 files changed (+63 lines, -56 lines)

torchao/prototype/moe_training/benchmarks/benchmark_kernels.py

Lines changed: 11 additions & 10 deletions
@@ -6,13 +6,13 @@
 # this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

 import itertools
-import time
 from dataclasses import dataclass
 from typing import List

 import torch
 from tabulate import tabulate
 from tqdm import tqdm
+from triton.testing import do_bench

 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
@@ -129,18 +129,15 @@ def run_triton(

     # bench torch
     compiled_run_torch = torch.compile(run_torch)
-    warmup(compiled_run_torch, input_row_major, input_col_major, offs)
-    start_time_ns = time.perf_counter_ns()
-    compiled_run_torch(input_row_major, input_col_major, offs)
-    torch_time_ns = time.perf_counter_ns() - start_time_ns
-    torch_time_us = torch_time_ns / 1e3
+    torch_time_us = benchmark_cuda_function_in_microseconds(
+        compiled_run_torch, input_row_major, input_col_major, offs
+    )

     # bench triton
     warmup(run_triton, input_row_major, input_col_major, offs)
-    start_time_ns = time.perf_counter_ns()
-    run_triton(input_row_major, input_col_major, offs)
-    triton_time_ns = time.perf_counter_ns() - start_time_ns
-    triton_time_us = triton_time_ns / 1e3
+    triton_time_us = benchmark_cuda_function_in_microseconds(
+        run_triton, input_row_major, input_col_major, offs
+    )

     return ExperimentResult(
         torch_time_us=torch_time_us,
@@ -173,6 +170,10 @@ def print_results(experiments: List[Experiment]):
     print(tabulate(rows, headers=headers))


+def benchmark_cuda_function_in_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
+
+
 def main():
     torch.random.manual_seed(123)
     configs = get_configs()
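
Note on the timing change: `triton.testing.do_bench` runs the callable repeatedly, times it with CUDA events, and returns a runtime in milliseconds, so taking the median and multiplying by 1e3 yields microseconds. Unlike the removed host-side `time.perf_counter_ns()` measurement, this accounts for asynchronous kernel execution. A minimal, self-contained sketch of the new helper; the matmul workload is illustrative only and not part of this commit:

import torch
from triton.testing import do_bench


def benchmark_cuda_function_in_microseconds(f, *args):
    # do_bench reports milliseconds; take the median and convert to microseconds.
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


if __name__ == "__main__":
    # Illustrative workload; the benchmark script times the jagged fp8 scaling kernels instead.
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    us = benchmark_cuda_function_in_microseconds(torch.matmul, a, b)
    print(f"torch.matmul median time: {us:.1f} us")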

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 47 additions & 35 deletions
@@ -16,8 +16,6 @@
 import triton
 import triton.language as tl

-from torchao.prototype.moe_training.utils import _is_column_major
-
 EPS = 1e-12

 FP8_DTYPE_MAP = {
@@ -33,13 +31,20 @@
     torch.float64: tl.float64,
 }

-block_sizes = [128, 256]
+block_sizes = [1, 16, 32, 64]
+block_sizes_iter = [32, 64, 128, 256]
+num_warps = [1, 4]
+num_stages = [2, 3]
 kernel_configs_2D = [
     triton.Config(
-        {"BLOCK_SIZE_ROWS": block_size_rows, "BLOCK_SIZE_COLS": block_size_cols}
+        {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
+        num_warps=warps,
+        num_stages=stages,
     )
-    for block_size_rows in block_sizes
-    for block_size_cols in block_sizes
+    for block_size in block_sizes
+    for block_size_iter in block_sizes_iter
+    for warps in num_warps
+    for stages in num_stages
 ]

 from torch.library import triton_op, wrap_triton
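
For reference, the old autotune sweep covered 2 x 2 = 4 configurations, while the new one covers 4 x 4 x 2 x 2 = 64: a small per-program BLOCK_SIZE is now paired with a separate BLOCK_SIZE_ITER used only for the inner loop over the jagged dimension. A standalone sketch of the expanded grid, mirroring the diff above (assumes triton is installed; the comments are my reading of how the two constants are used in the kernels below):

import triton

block_sizes = [1, 16, 32, 64]          # BLOCK_SIZE: rows (or cols) each program covers
block_sizes_iter = [32, 64, 128, 256]  # BLOCK_SIZE_ITER: tile stepped over the jagged dim
num_warps = [1, 4]
num_stages = [2, 3]

kernel_configs_2D = [
    triton.Config(
        {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
        num_warps=warps,
        num_stages=stages,
    )
    for block_size in block_sizes
    for block_size_iter in block_sizes_iter
    for warps in num_warps
    for stages in num_stages
]

# 4 * 4 * 2 * 2 = 64 candidate configs for the autotuner to try
assert len(kernel_configs_2D) == 64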
@@ -68,7 +73,6 @@ def triton_fp8_row_major_jagged_rowwise_scales(
     - jagged rowwise scales (i.e., rowwise scales for each group)
     """
     assert hp_tensor.ndim == 2, "input tensor must be 2D"
-    assert hp_tensor.is_contiguous(), "input tensor must be contiguous"

     num_elements = hp_tensor.numel()
     tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -81,16 +85,14 @@ def triton_fp8_row_major_jagged_rowwise_scales(
     n_groups = offsets.numel()

     # allocate on-device buffers for output and scales
-    output_buffer = torch.empty_like(
-        hp_tensor, dtype=output_dtype, device=hp_tensor.device
-    )
+    output_buffer = torch.empty((m, k), dtype=output_dtype, device=hp_tensor.device)
     scales_buffer = torch.empty(
         (m * n_groups), dtype=torch.float32, device=hp_tensor.device
     )

     # parallelize across rows and groups (offsets)
     grid = lambda meta: (
-        triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
+        triton.cdiv(m, meta["BLOCK_SIZE"]),
         offsets.numel(),
     )
     wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
@@ -115,7 +117,13 @@ def triton_fp8_row_major_jagged_rowwise_scales(
     return output_buffer, scales_buffer


-@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
+# This kernel is used on grad_output.t() which has shape (K, M),
+# before the calculation `grad_B = grad_output_t @ input`.
+# However, in this code, we use the conventional dim names (M, K)
+# so the kernel is easily interpretable in a standalone fashion.
+# The tokens per expert will vary per iteration, so we don't want
+# to recompile on `token` dim (K, in this case) changes.
+@triton.autotune(configs=kernel_configs_2D, key=["M"])
 @triton.jit
 def _triton_fp8_row_major_jagged_rowwise_scales(
     input_ptr,
@@ -134,8 +142,8 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     input_dtype: tl.constexpr,
     output_dtype: tl.constexpr,
     round_scales_to_power_of_2: tl.constexpr,
-    BLOCK_SIZE_ROWS: tl.constexpr,
-    BLOCK_SIZE_COLS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BLOCK_SIZE_ITER: tl.constexpr,
     EPS: tl.constexpr,
 ):
     # parallel across rows and groups (offsets)
@@ -147,12 +155,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
         offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
     )
     group_col_end_idx = tl.load(offsets_ptr + offset_idx)
-    block_row_offs = block_row_id * BLOCK_SIZE_ROWS + tl.arange(0, BLOCK_SIZE_ROWS)
+    block_row_offs = block_row_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

     # compute rowwise amaxes for this group
-    amax_buffer = tl.zeros((BLOCK_SIZE_ROWS,), dtype=input_dtype)
-    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
-        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
+    amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
+    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
+        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
         block_offs = (
             block_row_offs[:, None] * stride_input_row
             + block_col_offs[None, :] * stride_input_col
@@ -180,12 +188,12 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     # store rowwise scales for each group in contiguous memory:
     # [group0_row0, group_0_row1, ..., group2_row0, group2_row1]
    scales_offs = block_row_offs + (M * offset_idx)
-    scales_mask = tl.arange(0, BLOCK_SIZE_ROWS) < M
+    scales_mask = tl.arange(0, BLOCK_SIZE) < M
     tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)

     # perform float8 conversion for this group
-    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_COLS):
-        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_COLS)
+    for col_start_idx in range(group_col_start_idx, group_col_end_idx, BLOCK_SIZE_ITER):
+        block_col_offs = col_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
         block_offs = (
             block_row_offs[:, None] * stride_input_row
             + block_col_offs[None, :] * stride_input_col
@@ -230,7 +238,6 @@ def triton_fp8_col_major_jagged_colwise_scales(
     - jagged column-wise scales (i.e., column-wise scales for each group)
     """
     assert hp_tensor.ndim == 2, "input tensor must be 2D"
-    assert _is_column_major(hp_tensor), "input tensor must be column-major"

     num_elements = hp_tensor.numel()
     tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
@@ -242,17 +249,18 @@
     k, n = hp_tensor.shape
     n_groups = offsets.numel()

-    # allocate on-device buffers for output and scales
+    # Output buffer in column major
     output_buffer = torch.empty_like(
         hp_tensor, dtype=output_dtype, device=hp_tensor.device
-    )
+    ).as_strided(hp_tensor.size(), (1, k))
+
     scales_buffer = torch.empty(
         (n * n_groups), dtype=torch.float32, device=hp_tensor.device
     )

     # parallelize across columns and groups (offsets)
     grid = lambda meta: (
-        triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
+        triton.cdiv(n, meta["BLOCK_SIZE"]),
         offsets.numel(),
     )
     wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
277285
return output_buffer, scales_buffer
278286

279287

280-
@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
288+
# This kernel is used on `input` which has shape (M, K),
289+
# before the calculation `grad_B = grad_output_t @ input`.
290+
# The tokens per expert will vary per iteration, so don't want
291+
# to recompile on `token` dim (M) changes.
292+
@triton.autotune(configs=kernel_configs_2D, key=["K"])
281293
@triton.jit
282294
def _triton_fp8_col_major_jagged_colwise_scales(
283295
input_ptr,
@@ -296,8 +308,8 @@ def _triton_fp8_col_major_jagged_colwise_scales(
296308
input_dtype: tl.constexpr,
297309
output_dtype: tl.constexpr,
298310
round_scales_to_power_of_2: tl.constexpr,
299-
BLOCK_SIZE_ROWS: tl.constexpr,
300-
BLOCK_SIZE_COLS: tl.constexpr,
311+
BLOCK_SIZE: tl.constexpr,
312+
BLOCK_SIZE_ITER: tl.constexpr,
301313
EPS: tl.constexpr,
302314
):
303315
# parallel across columns and groups (offsets)
@@ -309,12 +321,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
309321
offsets_ptr + offset_idx - 1, mask=offset_idx > 0, other=0
310322
)
311323
group_row_end_idx = tl.load(offsets_ptr + offset_idx)
312-
block_col_offs = block_col_id * BLOCK_SIZE_COLS + tl.arange(0, BLOCK_SIZE_COLS)
324+
block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
313325

314326
# compute colwise amaxes for this group
315-
amax_buffer = tl.zeros((BLOCK_SIZE_COLS,), dtype=input_dtype)
316-
for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
317-
block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
327+
amax_buffer = tl.zeros((BLOCK_SIZE,), dtype=input_dtype)
328+
for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
329+
block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
318330
block_offs = (
319331
block_row_offs[:, None] * stride_input_row
320332
+ block_col_offs[None, :] * stride_input_col
@@ -343,12 +355,12 @@ def _triton_fp8_col_major_jagged_colwise_scales(
343355
# [group0_col0, group_0_col1, ..., group2_col0, group2_col1]
344356
# note: input tensor is in col-major memory layout.
345357
scales_offs = block_col_offs + (N * offset_idx)
346-
scales_mask = tl.arange(0, BLOCK_SIZE_COLS) < N
358+
scales_mask = tl.arange(0, BLOCK_SIZE) < N
347359
tl.store(scales_ptr + scales_offs, scales, mask=scales_mask)
348360

349361
# perform float8 conversion for this group
350-
for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ROWS):
351-
block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ROWS)
362+
for row_start_idx in range(group_row_start_idx, group_row_end_idx, BLOCK_SIZE_ITER):
363+
block_row_offs = row_start_idx + tl.arange(0, BLOCK_SIZE_ITER)
352364
block_offs = (
353365
block_row_offs[:, None] * stride_input_row
354366
+ block_col_offs[None, :] * stride_input_col

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 5 additions & 11 deletions
@@ -217,35 +217,29 @@ def backward(ctx, grad_output: torch.Tensor):
             use_fast_accum=True,
         )

-        # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
-        # needed for grad_B: grad_output_t @ A
-        grad_output_t_row_major = grad_output.transpose(-2, -1).contiguous()
-
-        # Convert A to float8, column-major for right operand of grouped GEMM:
-        # needed for grad_B: grad_output @ A
-        A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)
-
         # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determining the "groups."
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
+
+        # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
+        # needed for grad_B: grad_output_t @ A
         grad_output_t_fp8_row_major, grad_output_t_scales = (
             triton_fp8_row_major_jagged_rowwise_scales(
-                grad_output_t_row_major,
+                grad_output.transpose(-2, -1),
                 offs,
                 torch.float8_e4m3fn,
                 round_scales_to_power_of_2=True,
             )
         )

         A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
-            A_col_major,
+            A,
             offs,
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )

         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        # grad_B = (N,M) @ (M,K) = (N,K)
         assert not _is_column_major(grad_output_t_fp8_row_major), (
             "grad_output_t must be row-major for grad_B = grad_output_t @ A"
         )
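
The removed `.contiguous()` calls were materializing full copies of `grad_output.t()` and `A` on every backward pass; since the scaling kernels index through explicit row/column strides and the contiguity/column-major asserts were dropped, they can now consume the transposed views directly. A small sketch of why a transposed view costs nothing, with illustrative shapes not taken from this commit:

import torch

grad_output = torch.randn(8, 16)  # stand-in for a 2D grad_output

grad_output_t = grad_output.transpose(-2, -1)  # a view: shape (16, 8), no data movement
print(grad_output_t.data_ptr() == grad_output.data_ptr())  # True: same storage
print(grad_output.stride(), grad_output_t.stride())        # (16, 1) vs (1, 16)

# The old code forced a copy to make the transpose row-major before quantizing:
grad_output_t_row_major = grad_output_t.contiguous()
print(grad_output_t_row_major.data_ptr() == grad_output.data_ptr())  # False: a new allocation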
