
Commit f6688be

[moe training] add fp8 rowwise kernels for expert weights
stack-info: PR: #2696, branch: danielvegamyhre/stack/30
1 parent 221f807 commit f6688be

File tree

5 files changed: 333 additions, 15 deletions
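
For orientation before the per-file diffs: the new public entry point in this stack is triton_fp8_rowwise_3d_transpose_rhs. A minimal usage sketch, mirroring the shapes and arguments of the new test added below (the sketch itself is not part of the commit):

import torch
from torchao.prototype.moe_training.kernels.float8_rowwise import (
    triton_fp8_rowwise_3d_transpose_rhs,
)

# Expert weights enter the MoE forward pass as a transposed (E, K, N) view.
experts, n, k = 8, 4 * 5120, 5120
w = torch.randn(experts, n, k, dtype=torch.bfloat16, device="cuda")
w_t = w.transpose(-2, -1)  # (E, K, N)

# Returns the fp8 data in (E, N, K) column-major layout and rowwise scales of shape (E, K).
fp8_rhs, scales = triton_fp8_rowwise_3d_transpose_rhs(
    w_t,
    output_dtype=torch.float8_e4m3fn,
    round_scales_to_power_of_2=True,
)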

test/prototype/moe_training/test_kernels.py

Lines changed: 39 additions & 4 deletions
@@ -19,14 +19,18 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+from torchao.prototype.moe_training.kernels.float8_rowwise import (
+    triton_fp8_rowwise_3d_transpose_rhs,
+)
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
-    _to_2d_jagged_float8_tensor_colwise,
-    _to_2d_jagged_float8_tensor_rowwise,
+    torch_to_2d_jagged_float8_tensor_colwise,
+    torch_to_2d_jagged_float8_tensor_rowwise,
+    torch_to_3d_rowwise_float8_transpose_rhs,
 )
 from torchao.testing.utils import skip_if_rocm
 
@@ -42,7 +46,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
     colwise_offs = torch.arange(k, k * n_groups + 1, k, device=device)
 
     # compute reference with torch impl
-    ref_fp8_data, ref_scales = _to_2d_jagged_float8_tensor_rowwise(
+    ref_fp8_data, ref_scales = torch_to_2d_jagged_float8_tensor_rowwise(
         x,
         colwise_offs,
         target_dtype=torch.float8_e4m3fn,
@@ -70,7 +74,7 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: bool):
     rowwise_offs = torch.arange(m, m * n_groups + 1, m, device=device)
 
     # compute reference with torch impl
-    ref_fp8_data, ref_scales = _to_2d_jagged_float8_tensor_colwise(
+    ref_fp8_data, ref_scales = torch_to_2d_jagged_float8_tensor_colwise(
         x,
         rowwise_offs,
         target_dtype=torch.float8_e4m3fn,
@@ -85,3 +89,34 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: bool):
     assert torch.eq(ref_fp8_data, kernel_fp8_data).all(), "fp8 data not equal"
     assert torch.eq(ref_scales, kernel_scales).all(), "scales not equal"
     assert _is_column_major(kernel_fp8_data), "fp8 data is not column major"
+
+
+@skip_if_rocm("ROCm not supported")
+@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
+def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
+    device = "cuda"
+    experts, n, k = 8, 4 * 5120, 5120
+
+    # Example expert weights, transposed, as they come into the forward pass
+    torch.manual_seed(0)
+    x = torch.randn((experts, n, k), dtype=torch.bfloat16, device=device).transpose(
+        -2, -1
+    )
+
+    # Compute reference with torch impl
+    ref_fp8, ref_scales = torch_to_3d_rowwise_float8_transpose_rhs(
+        x,
+        target_dtype=torch.float8_e4m3fn,
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
+    )
+    # The PyTorch impl keeps the reduced (size-1) scale dim, so squeeze it out.
+    ref_scales = ref_scales.squeeze(1)
+
+    triton_fp8, triton_scales = triton_fp8_rowwise_3d_transpose_rhs(
+        x,
+        output_dtype=torch.float8_e4m3fn,
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
+    )
+
+    assert torch.allclose(ref_scales, triton_scales, rtol=0, atol=0), "scales not equal"
+    assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
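
The torch reference used above, torch_to_3d_rowwise_float8_transpose_rhs, lives in torchao.prototype.moe_training.utils and is not shown in this diff. As a rough eager-mode sketch of the semantics the test checks, assuming the reference follows the same rowwise recipe as the Triton kernel below (amax reduced over the N dim, amax clamped at EPS=1e-12, saturating cast); an illustration, not the actual utility:

import torch

def rowwise_fp8_transpose_rhs_sketch(
    w_t: torch.Tensor,  # (E, K, N) high-precision expert weights, transposed
    dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
):
    # Transpose back to the RHS layout (E, N, K) and compute one scale per (expert, k) column.
    b = w_t.transpose(-2, -1).float()          # (E, N, K)
    amax = b.abs().amax(dim=1, keepdim=True)   # (E, 1, K)
    scales = torch.finfo(dtype).max / amax.clamp(min=1e-12)
    if round_scales_to_power_of_2:
        scales = torch.exp2(torch.floor(torch.log2(scales)))
    fp8 = (b * scales).clamp(torch.finfo(dtype).min, torch.finfo(dtype).max).to(dtype)
    return fp8, scales                         # scales are (E, 1, K); the test squeezes dim 1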

torchao/prototype/moe_training/benchmarks/benchmark_kernels.py

Lines changed: 4 additions & 4 deletions
@@ -19,8 +19,8 @@
     triton_fp8_row_major_jagged_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
-    _to_2d_jagged_float8_tensor_colwise,
-    _to_2d_jagged_float8_tensor_rowwise,
+    torch_to_2d_jagged_float8_tensor_colwise,
+    torch_to_2d_jagged_float8_tensor_rowwise,
 )
 
 device = torch.device("cuda")
@@ -98,13 +98,13 @@ def warmup(func, *args, **kwargs):
 def run_torch(
     input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
 ):
-    _ = _to_2d_jagged_float8_tensor_rowwise(
+    _ = torch_to_2d_jagged_float8_tensor_rowwise(
         input_row_major,
         offs,
         target_dtype=torch.float8_e4m3fn,
         round_scales_to_power_of_2=True,
     )
-    _ = _to_2d_jagged_float8_tensor_colwise(
+    _ = torch_to_2d_jagged_float8_tensor_colwise(
         input_col_major,
         offs,
         target_dtype=torch.float8_e4m3fn,
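
The benchmark file is only updated for the rename here; it does not yet time the new 3D transpose kernel. A hypothetical extension using triton.testing.do_bench (the run_triton_3d name and shapes below are illustrative, not part of this commit):

import torch
from triton.testing import do_bench
from torchao.prototype.moe_training.kernels.float8_rowwise import (
    triton_fp8_rowwise_3d_transpose_rhs,
)

def run_triton_3d(w_t: torch.Tensor):
    _ = triton_fp8_rowwise_3d_transpose_rhs(
        w_t,
        output_dtype=torch.float8_e4m3fn,
        round_scales_to_power_of_2=True,
    )

w_t = torch.randn(8, 8192, 5120, dtype=torch.bfloat16, device="cuda").transpose(-2, -1)
ms = do_bench(lambda: run_triton_3d(w_t))
print(f"triton_fp8_rowwise_3d_transpose_rhs: {ms:.3f} ms")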

torchao/prototype/moe_training/kernels/float8_rowwise.py (new file)

Lines changed: 251 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
import triton
import triton.language as tl

EPS = 1e-12

FP8_DTYPE_MAP = {
    torch.int8: tl.int8,
    torch.int16: tl.int16,
    torch.int32: tl.int32,
    torch.int64: tl.int64,
    torch.float8_e4m3fn: tl.float8e4nv,
    torch.float8_e5m2: tl.float8e5,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
    torch.float32: tl.float32,
    torch.float64: tl.float64,
}

block_sizes = [16]
num_warps = [4]
num_stages = [2]
kernel_configs_2D = [
    triton.Config(
        {"BLOCK_SIZE_N": block_size, "BLOCK_SIZE_K": block_size * 2},
        num_warps=warps,
        num_stages=stages,
    )
    for block_size in block_sizes
    for warps in num_warps
    for stages in num_stages
]

from torch.library import triton_op, wrap_triton


@triton_op("torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={})
def triton_fp8_rowwise_3d_transpose_rhs(
    hp_tensor: torch.Tensor,  # (E, K, N)
    output_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert hp_tensor.ndim == 3, "input tensor must be 3D"

    num_elements = hp_tensor.numel()
    tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
    tl_output_dtype = FP8_DTYPE_MAP[output_dtype]

    fp8_dtype_min = torch.finfo(output_dtype).min
    fp8_dtype_max = torch.finfo(output_dtype).max

    e, k, n = hp_tensor.shape

    # allocate on-device buffers for output and scales
    # output shape = input.transpose(-2, -1).shape = (E, N, K) in column major layout
    output_buffer = torch.empty((e, k, n), dtype=output_dtype, device=hp_tensor.device)
    output_buffer = output_buffer.transpose(-2, -1)
    scales_buffer = torch.full(
        (e, k), float("inf"), dtype=torch.float32, device=hp_tensor.device
    )

    # parallelize across experts, and for each expert, parallelize across rows and cols
    grid = lambda meta: (
        e,
        triton.cdiv(k, meta["BLOCK_SIZE_K"]),
        triton.cdiv(n, meta["BLOCK_SIZE_N"]),
    )

    # compute scales
    wrap_triton(_triton_fp8_rowwise_3d_transpose_scales_rhs_kernel)[grid](
        hp_tensor,
        hp_tensor.stride(0),
        hp_tensor.stride(1),
        hp_tensor.stride(2),
        scales_buffer,
        scales_buffer.stride(0),
        scales_buffer.stride(1),
        e,
        n,
        k,
        num_elements,
        fp8_dtype_min,
        fp8_dtype_max,
        tl_input_dtype,
        round_scales_to_power_of_2=round_scales_to_power_of_2,
        EPS=EPS,
    )

    # perform casting
    wrap_triton(_triton_fp8_rowwise_3d_transpose_cast_rhs_kernel)[grid](
        hp_tensor,
        hp_tensor.stride(0),
        hp_tensor.stride(1),
        hp_tensor.stride(2),
        output_buffer,
        output_buffer.stride(0),
        output_buffer.stride(1),
        output_buffer.stride(2),
        scales_buffer,
        scales_buffer.stride(0),
        scales_buffer.stride(1),
        e,
        n,
        k,
        num_elements,
        fp8_dtype_min,
        fp8_dtype_max,
        tl_input_dtype,
        tl_output_dtype,
    )
    return output_buffer, scales_buffer


@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
    input_ptr,
    stride_input_dim0: int,
    stride_input_dim1: int,
    stride_input_dim2: int,
    scales_ptr,
    stride_scales_dim0: int,
    stride_scales_dim1: int,
    E: int,
    N: int,
    K: int,
    num_elements: int,
    fp8_dtype_min: tl.constexpr,
    fp8_dtype_max: tl.constexpr,
    input_dtype: tl.constexpr,
    round_scales_to_power_of_2: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    EPS: tl.constexpr,
):
    # parallelize across experts, rows, and cols
    expert_idx = tl.program_id(0)
    k_block_idx = tl.program_id(1)
    n_block_idx = tl.program_id(2)

    # compute offsets for each dimension
    k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    n_offs = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # load block of input data, shape (K, N)
    input_offs = (
        expert_idx * stride_input_dim0
        + k_offs[:, None] * stride_input_dim1
        + (n_offs[None, :] * stride_input_dim2)
    )
    input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
    input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
        input_dtype
    )

    # compute scales with the local amax, reducing over axis=1 (the N dim of the
    # non-transposed (K, N) block), which yields one scale per K row, i.e. rowwise
    # scales for the transposed (N, K) output.
    amaxes = tl.max(tl.abs(input_data), axis=1).to(tl.float64)  # (K,)
    scales = (fp8_dtype_max / tl.clamp(amaxes, min=EPS, max=float("inf"))).to(
        tl.float32
    )
    if round_scales_to_power_of_2:
        scales = tl.exp2(tl.floor(tl.log2(scales)))

    # compute global scales using atomics with local scales - shape (1, K)
    scales_offs = (
        expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1
    )
    scales_mask = k_offs[None, :] < K
    tl.atomic_min(scales_ptr + scales_offs, scales[None, :], mask=scales_mask)


@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel(
    input_ptr,
    stride_input_dim0: int,
    stride_input_dim1: int,
    stride_input_dim2: int,
    output_ptr,
    stride_output_dim0: int,
    stride_output_dim1: int,
    stride_output_dim2: int,
    scales_ptr,
    stride_scales_dim0: int,
    stride_scales_dim1: int,
    E: int,
    N: int,
    K: int,
    num_elements: int,
    fp8_dtype_min: tl.constexpr,
    fp8_dtype_max: tl.constexpr,
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # parallelize across experts, rows, and cols
    expert_idx = tl.program_id(0)
    k_block_idx = tl.program_id(1)
    n_block_idx = tl.program_id(2)

    # compute offsets for each dimension
    k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    n_offs = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # load block of input data for this expert - shape (K, N)
    input_offs = (
        expert_idx * stride_input_dim0
        + k_offs[:, None] * stride_input_dim1
        + (n_offs[None, :] * stride_input_dim2)
    )
    input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
    input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
        input_dtype
    )
    input_data = input_data.trans(1, 0)  # (K, N) -> (N, K)

    # load global scales for this block of the given expert - shape (1, K)
    scales_offs = (
        expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1
    )
    scales_mask = k_offs[None, :] < K
    scales = tl.load(scales_ptr + scales_offs, mask=scales_mask, other=0.0).to(
        tl.float32
    )

    # apply scales to the transposed data - shape (N, K) * (1, K) = (N, K)
    scaled_data = input_data * scales
    output_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
        output_dtype
    )

    # store the transposed output data - shape (N, K)
    output_offs = (
        expert_idx * stride_output_dim0
        + n_offs[:, None] * stride_output_dim1
        + (k_offs[None, :] * stride_output_dim2)
    )
    output_mask = (n_offs[:, None] < N) & (k_offs[None, :] < K)
    tl.store(output_ptr + output_offs, output_data, mask=output_mask)
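
A note on the two-kernel structure above: a single (expert, k) row can span several BLOCK_SIZE_N tiles, so each program of the scales kernel only sees a local amax. Because scale = fp8_max / amax decreases as amax grows, the global rowwise scale equals the minimum of the per-tile scales, which is what tl.atomic_min into the inf-initialized scales buffer accumulates before the cast kernel consumes it. A small eager check of that identity (illustrative only):

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max
row = torch.randn(4096, dtype=torch.float32)

# Scale derived from the full row's amax.
global_scale = fp8_max / row.abs().amax().clamp(min=1e-12)

# Per-tile scales, reduced with min, recover the same value.
tile_scales = torch.stack(
    [fp8_max / t.abs().amax().clamp(min=1e-12) for t in row.chunk(16)]
)
assert torch.equal(tile_scales.min(), global_scale)

Since the wrapper is registered via @triton_op("torchao::triton_fp8_rowwise_transpose_rhs", ...), it should also be reachable as torch.ops.torchao.triton_fp8_rowwise_transpose_rhs once the module has been imported.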

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 5 additions & 5 deletions
@@ -111,9 +111,9 @@ def forward(
         A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
-        # B shape: (E, K, N)
-        # B scales must be computed rowwise keeping the outer/final dim, so:
-        # B_scales shape: (E, 1, N)
+        # B_t shape: (E, K, N)
+        # B_t scales must be computed rowwise keeping the outer/final dim, so:
+        # B_t_scales shape: (E, 1, N)
         B_t_scales = tensor_to_scale(
             B_t,
             torch.float8_e4m3fn,
@@ -129,9 +129,9 @@ def forward(
         # In the backward this is needed for grad_A: grad_output @ B.
         B = B_t.contiguous().transpose(-2, -1)
 
-        # - B shape: (E, K, N)
+        # - B shape: (E, N, K)
         # - B scales must be computed rowwise keeping the outer/final dim, so:
-        # - B_scale shape: (E, 1, N)
+        # - B_scale shape: (E, 1, K)
         B_scales = tensor_to_scale(
             B,
             torch.float8_e4m3fn,
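
The comment fixes above are purely about shapes: B_t stays (E, K, N) with rowwise scales of shape (E, 1, N), while the contiguous copy B used in the backward is (E, N, K) with rowwise scales of shape (E, 1, K). A small shape check using a plain amax-based scale in place of tensor_to_scale, whose exact signature is not shown in this diff:

import torch

E, K, N = 8, 5120, 8192
B_t = torch.randn(E, K, N, dtype=torch.bfloat16)
B = B_t.contiguous().transpose(-2, -1)  # (E, N, K)

fp8_max = torch.finfo(torch.float8_e4m3fn).max

# Rowwise scales keep the outer/final dim of each operand.
B_t_scales = fp8_max / B_t.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
B_scales = fp8_max / B.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)

assert B_t_scales.shape == (E, 1, N)
assert B_scales.shape == (E, 1, K)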
