Skip to content

[moe training] add fp8 rowwise kernels for expert weights #2696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions benchmarks/prototype/moe_training/benchmark_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
triton_fp8_row_major_jagged_rowwise_scales,
)
from torchao.prototype.moe_training.utils import (
_to_2d_jagged_float8_tensor_colwise,
_to_2d_jagged_float8_tensor_rowwise,
torch_to_float8_per_group_colwise,
torch_to_float8_per_group_rowwise,
)

device = torch.device("cuda")
Expand Down Expand Up @@ -98,13 +98,13 @@ def warmup(func, *args, **kwargs):
def run_torch(
input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
):
_ = _to_2d_jagged_float8_tensor_rowwise(
_ = torch_to_float8_per_group_rowwise(
input_row_major,
offs,
target_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)
_ = _to_2d_jagged_float8_tensor_colwise(
_ = torch_to_float8_per_group_colwise(
input_col_major,
offs,
target_dtype=torch.float8_e4m3fn,
Expand Down
47 changes: 43 additions & 4 deletions test/prototype/moe_training/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


from torchao.prototype.moe_training.kernels.float8_rowwise import (
triton_fp8_rowwise_3d_transpose_rhs,
)
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_col_major_jagged_colwise_scales,
triton_fp8_row_major_jagged_rowwise_scales,
)
from torchao.prototype.moe_training.utils import (
_is_column_major,
_to_2d_jagged_float8_tensor_colwise,
_to_2d_jagged_float8_tensor_rowwise,
torch_to_3d_rowwise_float8_transpose_rhs,
torch_to_float8_per_group_colwise,
torch_to_float8_per_group_rowwise,
)
from torchao.testing.utils import skip_if_rocm

Expand All @@ -42,7 +46,7 @@ def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
colwise_offs = torch.arange(k, k * n_groups + 1, k, device=device)

# compute reference with torch impl
ref_fp8_data, ref_scales = _to_2d_jagged_float8_tensor_rowwise(
ref_fp8_data, ref_scales = torch_to_float8_per_group_rowwise(
x,
colwise_offs,
target_dtype=torch.float8_e4m3fn,
Expand Down Expand Up @@ -70,7 +74,7 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo
rowwise_offs = torch.arange(m, m * n_groups + 1, m, device=device)

# compute reference with torch impl
ref_fp8_data, ref_scales = _to_2d_jagged_float8_tensor_colwise(
ref_fp8_data, ref_scales = torch_to_float8_per_group_colwise(
x,
rowwise_offs,
target_dtype=torch.float8_e4m3fn,
Expand All @@ -85,3 +89,38 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo
assert torch.eq(ref_fp8_data, kernel_fp8_data).all(), "fp8 data not equal"
assert torch.eq(ref_scales, kernel_scales).all(), "scales not equal"
assert _is_column_major(kernel_fp8_data), "fp8 data is not column major"


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
device = "cuda"
experts, n, k = 8, 4 * 5120, 5120

# Example expert weights as it comes into forward transposed
torch.manual_seed(0)
x = torch.randn((experts, n, k), dtype=torch.bfloat16, device=device).transpose(
-2, -1
)

# Compute reference with torch impl
ref_fp8, ref_scales = torch_to_3d_rowwise_float8_transpose_rhs(
x,
target_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=round_scales_to_power_of_2,
)
# Torch impl keeps empty scaled dim, so we squeeze it out to be consistent with triton impl
ref_scales = ref_scales.squeeze(1)

triton_fp8, triton_scales = triton_fp8_rowwise_3d_transpose_rhs(
x,
output_dtype=torch.float8_e4m3fn,
round_scales_to_power_of_2=round_scales_to_power_of_2,
)
assert ref_scales.shape == triton_scales.shape, "scale shapes not equal"
assert ref_scales.stride() == triton_scales.stride(), "scale strides not equal"
assert torch.allclose(ref_scales, triton_scales, rtol=0, atol=0), "scales not equal"

assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
251 changes: 251 additions & 0 deletions torchao/prototype/moe_training/kernels/float8_rowwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
import triton
import triton.language as tl

EPS = 1e-12

FP8_DTYPE_MAP = {
torch.int8: tl.int8,
torch.int16: tl.int16,
torch.int32: tl.int32,
torch.int64: tl.int64,
torch.float8_e4m3fn: tl.float8e4nv,
torch.float8_e5m2: tl.float8e5,
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32,
torch.float64: tl.float64,
}

block_sizes = [16]
num_warps = [4]
num_stages = [2]
kernel_configs_2D = [
triton.Config(
{"BLOCK_SIZE_N": block_size, "BLOCK_SIZE_K": block_size * 2},
num_warps=warps,
num_stages=stages,
)
for block_size in block_sizes
for warps in num_warps
for stages in num_stages
]

from torch.library import triton_op, wrap_triton


@triton_op("torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={})
def triton_fp8_rowwise_3d_transpose_rhs(
hp_tensor: torch.Tensor, # (E, K, N)
output_dtype: torch.dtype = torch.float8_e4m3fn,
round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert hp_tensor.ndim == 3, "input tensor must be 3D"

num_elements = hp_tensor.numel()
tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
tl_output_dtype = FP8_DTYPE_MAP[output_dtype]

fp8_dtype_min = torch.finfo(output_dtype).min
fp8_dtype_max = torch.finfo(output_dtype).max

e, k, n = hp_tensor.shape

# allocate on-device buffers for output and scales
# output shape = input.transpose(-2, -1).shape = (E, N, K) in column major layout
output_buffer = torch.empty((e, k, n), dtype=output_dtype, device=hp_tensor.device)
output_buffer = output_buffer.transpose(-2, -1)
scales_buffer = torch.full(
(e, k), float("inf"), dtype=torch.float32, device=hp_tensor.device
)

# parallelize across experts, and for each expert, parallelize across rows and cols
grid = lambda meta: (
e,
triton.cdiv(k, meta["BLOCK_SIZE_K"]),
triton.cdiv(n, meta["BLOCK_SIZE_N"]),
)

# compute scales
wrap_triton(_triton_fp8_rowwise_3d_transpose_scales_rhs_kernel)[grid](
hp_tensor,
hp_tensor.stride(0),
hp_tensor.stride(1),
hp_tensor.stride(2),
scales_buffer,
scales_buffer.stride(0),
scales_buffer.stride(1),
e,
n,
k,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
tl_input_dtype,
round_scales_to_power_of_2=round_scales_to_power_of_2,
EPS=EPS,
)

# perform casting
wrap_triton(_triton_fp8_rowwise_3d_transpose_cast_rhs_kernel)[grid](
hp_tensor,
hp_tensor.stride(0),
hp_tensor.stride(1),
hp_tensor.stride(2),
output_buffer,
output_buffer.stride(0),
output_buffer.stride(1),
output_buffer.stride(2),
scales_buffer,
scales_buffer.stride(0),
scales_buffer.stride(1),
e,
n,
k,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
tl_input_dtype,
tl_output_dtype,
)
return output_buffer, scales_buffer


@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
input_ptr,
stride_input_dim0: int,
stride_input_dim1: int,
stride_input_dim2: int,
scales_ptr,
stride_scales_dim0: int,
stride_scales_dim1: int,
E: int,
N: int,
K: int,
num_elements: int,
fp8_dtype_min: tl.constexpr,
fp8_dtype_max: tl.constexpr,
input_dtype: tl.constexpr,
round_scales_to_power_of_2: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
EPS: tl.constexpr,
):
# parallelize across experts, rows, and cols
expert_idx = tl.program_id(0)
k_block_idx = tl.program_id(1)
n_block_idx = tl.program_id(2)

# compute offsets for each dimension
k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
n_offs = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

# load block of input data, shape (K, N)
input_offs = (
expert_idx * stride_input_dim0
+ k_offs[:, None] * stride_input_dim1
+ (n_offs[None, :] * stride_input_dim2)
)
input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
input_dtype
)

# compute scales with local amax, using axis=0 because for each expert,
# we are reading the non-transposed input, and want to compute the scales
# along axis=1 for the transposed input.
amaxes = tl.max(tl.abs(input_data), axis=1).to(tl.float64) # (K,)
scales = (fp8_dtype_max / tl.clamp(amaxes, min=EPS, max=float("inf"))).to(
tl.float32
)
if round_scales_to_power_of_2:
scales = tl.exp2(tl.floor(tl.log2(scales)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems expensive, can we just extract the bits?


# compute global scales using atomics with local scales - shape (1, K)
scales_offs = (
expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1
)
scales_mask = k_offs[None, :] < K
tl.atomic_min(scales_ptr + scales_offs, scales[None, :], mask=scales_mask)


@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel(
input_ptr,
stride_input_dim0: int,
stride_input_dim1: int,
stride_input_dim2: int,
output_ptr,
stride_output_dim0: int,
stride_output_dim1: int,
stride_output_dim2: int,
scales_ptr,
stride_scales_dim0: int,
stride_scales_dim1: int,
E: int,
N: int,
K: int,
num_elements: int,
fp8_dtype_min: tl.constexpr,
fp8_dtype_max: tl.constexpr,
input_dtype: tl.constexpr,
output_dtype: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
# parallelize across experts, rows, and cols
expert_idx = tl.program_id(0)
k_block_idx = tl.program_id(1)
n_block_idx = tl.program_id(2)

# compute offsets for each dimension
k_offs = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
n_offs = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

# load block of input data for this expert - shape (K, N)
input_offs = (
expert_idx * stride_input_dim0
+ k_offs[:, None] * stride_input_dim1
+ (n_offs[None, :] * stride_input_dim2)
)
input_mask = (k_offs[:, None] < K) & (n_offs[None, :] < N)
input_data = tl.load(input_ptr + input_offs, mask=input_mask, other=0.0).to(
input_dtype
)
input_data = input_data.trans(1, 0) # (K, N) -> (N, K)

# load global scales for this block of the given expert - shape (1, K)
scales_offs = (
expert_idx[:, None] * stride_scales_dim0 + k_offs[None, :] * stride_scales_dim1
)
scales_mask = k_offs[None, :] < K
scales = tl.load(scales_ptr + scales_offs, mask=scales_mask, other=0.0).to(
tl.float32
)

# transpose data and apply scales - shape (N,K) * (1,K) = (N,K)
scaled_data = input_data * scales
output_data = tl.clamp(scaled_data, min=fp8_dtype_min, max=fp8_dtype_max).to(
output_dtype
)

# store transpose and store output data - shape (N, K)
output_offs = (
expert_idx * stride_output_dim0
+ n_offs[:, None] * stride_output_dim1
+ (k_offs[None, :] * stride_output_dim2)
)
output_mask = (n_offs[:, None] < N) & (k_offs[None, :] < K)
tl.store(output_ptr + output_offs, output_data, mask=output_mask)
10 changes: 5 additions & 5 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def forward(
A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)

# Convert B to float8, column-major for right operand of grouped GEMM.
# B shape: (E, K, N)
# B scales must be computed rowwise keeping the outer/final dim, so:
# B_scales shape: (E, 1, N)
# B_t shape: (E, K, N)
# B_t scales must be computed rowwise keeping the outer/final dim, so:
# B_t_scales shape: (E, 1, N)
B_t_scales = tensor_to_scale(
B_t,
torch.float8_e4m3fn,
Expand All @@ -144,9 +144,9 @@ def forward(
# In the backward this is needed for grad_A: grad_output @ B.
B = B_t.contiguous().transpose(-2, -1)

# - B shape: (E, K, N)
# - B shape: (E, N, K)
# - B scales must be computed rowwise keeping the outer/final dim, so:
# - B_scale shape: (E, 1, N)
# - B_scale shape: (E, 1, K)
B_scales = tensor_to_scale(
B,
torch.float8_e4m3fn,
Expand Down
Loading