
[moe training] integrate rowwise expert quant kernel #2698

Merged
merged 1 commit on Aug 7, 2025
3 changes: 3 additions & 0 deletions torchao/prototype/moe_training/kernels/__init__.py
@@ -1,3 +1,6 @@
+from torchao.prototype.moe_training.kernels.float8_rowwise import (
+    triton_fp8_rowwise_3d_transpose_rhs as triton_fp8_rowwise_3d_transpose_rhs,
+)
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )
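
With this re-export, downstream code can import the fused kernel straight from the kernels package, which is exactly what the next file does. A minimal usage sketch, assuming a CUDA device and the call shape shown in the scaled_grouped_mm.py hunk below:

    import torch
    from torchao.prototype.moe_training.kernels import triton_fp8_rowwise_3d_transpose_rhs

    # Per-expert RHS weights in transposed layout: (E, K, N).
    B_t = torch.randn(8, 4096, 1024, device="cuda", dtype=torch.bfloat16)
    B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
        B_t,
        output_dtype=torch.float8_e4m3fn,
        round_scales_to_power_of_2=True,
    )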
16 changes: 4 additions & 12 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -15,6 +15,7 @@
 from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
+    triton_fp8_rowwise_3d_transpose_rhs,
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -142,20 +143,11 @@ def forward(
         # Precompute non-transposed B column-major for backward, to save memory by storing the
         # low precision B tensor instead of the high precision B tensor.
         # In the backward this is needed for grad_A: grad_output @ B.
-        B = B_t.contiguous().transpose(-2, -1)
-
-        # - B shape: (E, N, K)
-        # - B scales must be computed rowwise keeping the outer/final dim, so:
-        #   - B_scale shape: (E, 1, K)
-        B_scales = tensor_to_scale(
-            B,
-            torch.float8_e4m3fn,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-2,
+        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+            B_t,
+            output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        B_scaled = B.to(torch.float32) * B_scales
-        B_fp8_col_major = to_fp8_saturated(B_scaled, torch.float8_e4m3fn)

         # Store what we need for backward.
         ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
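
The fused kernel collapses what the removed lines did eagerly: materialize B in column-major layout, compute per-expert rowwise scales keeping the outer/final dim, then quantize with a saturated cast. A rough eager-mode sketch of those semantics (the helper name is made up, and the power-of-2 rounding is assumed to floor the scale's exponent):

    import torch

    def fp8_rowwise_transpose_rhs_reference(B_t: torch.Tensor):
        # B_t: (E, K, N) high precision. Returns fp8 tensor (E, N, K) + scales (E, 1, K).
        B = B_t.contiguous().transpose(-2, -1)  # (E, N, K), column-major in last two dims
        fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
        # Rowwise amax keeping the outer/final dim -> scale shape (E, 1, K).
        amax = B.abs().amax(dim=-2, keepdim=True).clamp(min=1e-12)
        scales = (fp8_max / amax).to(torch.float32)
        # round_scales_to_power_of_2=True (assumed semantics): round each scale down to a power of 2.
        scales = torch.exp2(torch.floor(torch.log2(scales)))
        # Saturated cast: clamp to the fp8 representable range before converting.
        B_fp8 = (B.to(torch.float32) * scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return B_fp8, scales

Beyond kernel-launch savings, fusing these steps presumably avoids materializing the intermediate high precision transpose of B that the eager path created.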
2 changes: 1 addition & 1 deletion torchao/prototype/moe_training/utils.py
@@ -152,7 +152,7 @@ def torch_to_3d_rowwise_float8_transpose_rhs(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function converts the 3D input tensor to a float8 tensor, with scales computed along logical columns
-    on a per-expert basis.
+    on a per-expert basis. Output will be in column-major memory layout.

     Args:
         x (torch.Tensor): The input tensor to be converted to a float8 tensor. Shape (E, K, N).
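
Since the docstring now pins down the layout contract of this reference path, a quick way to check it is a hypothetical call like the one below; the rest of the signature is collapsed in this hunk, so the sketch assumes any remaining arguments have defaults:

    import torch
    from torchao.prototype.moe_training.utils import (
        torch_to_3d_rowwise_float8_transpose_rhs,
    )

    E, K, N = 4, 256, 128
    x = torch.randn(E, K, N)
    y_fp8, y_scales = torch_to_3d_rowwise_float8_transpose_rhs(x)
    # Column-major in the last two dims means unit stride along dim -2.
    print(y_fp8.shape, y_fp8.stride())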