Skip to content

backward pass for differentiable mxfp8 grouped gemm with dynamic quant #2639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions test/prototype/moe_training/test_scaled_grouped_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
import torch
from torch.nn import functional as F

pytest.importorskip("triton", reason="Triton required to run this test")

Expand Down Expand Up @@ -306,19 +307,47 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):

@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
from torchao.prototype.moe_training.scaled_grouped_mm import (
_MXFP8GroupedMM,
)

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M)
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
block_size = 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w_t = torch.randn(
num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
x_ref, w_t_ref, offs_ref = (
x.clone().detach().requires_grad_(True),
w_t.clone().detach().requires_grad_(True),
offs.clone(),
)

# Forward
out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
sqnr = compute_error(ref_out, out)
min_sqnr = 27.0
assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
assert sqnr >= min_sqnr, f"Output sqnr {sqnr} is too low, must be >= {min_sqnr}"

# Backward
labels = torch.ones_like(ref_out)
ref_loss = F.mse_loss(ref_out, labels)
out_loss = F.mse_loss(out, labels)
ref_loss.backward()
out_loss.backward()

# Check input grads
min_input_grad_sqnr = 26.0
sqnr = compute_error(x_ref.grad, x.grad)
assert sqnr >= min_input_grad_sqnr, (
f"Input grad sqnr {sqnr} is too low, must be >= {min_input_grad_sqnr}"
)

# Check weight grads
min_weight_grad_sqnr = 24.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lg!

sqnr = compute_error(w_t_ref.grad, w_t.grad)
assert sqnr >= min_weight_grad_sqnr, (
f"Weight grad sqnr {sqnr} is too low, must be >= {min_weight_grad_sqnr}"
)
47 changes: 46 additions & 1 deletion torchao/prototype/moe_training/scaled_grouped_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
)
from torchao.prototype.moe_training.utils import (
_is_column_major,
_to_mxfp8_per_group_colwise,
_to_mxfp8_per_group_rowwise,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

Expand Down Expand Up @@ -319,7 +321,50 @@ def forward(

@staticmethod
def backward(ctx, grad_out: torch.Tensor):
raise NotImplementedError
A, B_t, offs = ctx.saved_tensors
block_size = ctx.block_size
out_dtype = ctx.out_dtype
# Compute grad_A.
# grad_A = grad_output @ B
# grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
grad_out_scale, grad_out_mx = to_mx(
grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
)
B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(
B_t.transpose(-2, -1).contiguous(),
block_size=block_size,
elem_dtype=torch.float8_e4m3fn,
)
grad_A = emulated_mxfp8_scaled_grouped_mm(
grad_out_mx,
grad_out_scale,
B_t_mx,
B_t_scale,
offs=offs,
out_dtype=out_dtype,
)
# Compute grad_B = grad_output_t @ A
grad_out_t_mx, grad_out_t_scale = _to_mxfp8_per_group_rowwise(
grad_out.transpose(-2, -1).contiguous(),
offs=offs,
block_size=block_size,
)
A_mx, A_scale = _to_mxfp8_per_group_colwise(
A,
offs=offs,
block_size=block_size,
)
grad_B = emulated_mxfp8_scaled_grouped_mm(
grad_out_t_mx,
grad_out_t_scale,
A_mx,
A_scale,
offs=offs,
)
# In forward we receive pre-transposed weights B_t as input
grad_B_t = grad_B.transpose(-2, -1)

return grad_A, grad_B_t, None, None, None


def _to_mxfp8_3d_expert_weights_dim1(
Expand Down