
Commit e9ba18b

backward pass for differentiable mxfp8 grouped gemm with dynamic quant
stack-info: PR: #2639, branch: danielvegamyhre/stack/25
1 parent f6c2c3a commit e9ba18b

2 files changed (+81, -7 lines)

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 35 additions & 6 deletions
@@ -6,6 +6,7 @@
 
 import pytest
 import torch
+from torch.nn import functional as F
 
 pytest.importorskip("triton", reason="Triton required to run this test")
 
@@ -306,19 +307,47 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
 
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
-def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
+def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
     from torchao.prototype.moe_training.scaled_grouped_mm import (
         _MXFP8GroupedMM,
     )
 
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
-    offs = generate_jagged_offs(num_experts, M)
-    x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
     block_size = 32
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    w_t = torch.randn(
+        num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
+    )
+    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
+    x_ref, w_t_ref, offs_ref = (
+        x.clone().detach().requires_grad_(True),
+        w_t.clone().detach().requires_grad_(True),
+        offs.clone(),
+    )
 
+    # Forward
     out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
     ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
     sqnr = compute_error(ref_out, out)
     min_sqnr = 27.0
-    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
+    assert sqnr >= min_sqnr, f"Output sqnr {sqnr} is too low, must be >= {min_sqnr}"
+
+    # Backward
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+    ref_loss.backward()
+    out_loss.backward()
+
+    # Check input grads
+    min_input_grad_sqnr = 26.0
+    sqnr = compute_error(x_ref.grad, x.grad)
+    assert sqnr >= min_input_grad_sqnr, (
+        f"Input grad sqnr {sqnr} is too low, must be >= {min_input_grad_sqnr}"
+    )
+
+    # Check weight grads
+    min_weight_grad_sqnr = 24.0
+    sqnr = compute_error(w_t_ref.grad, w_t.grad)
+    assert sqnr >= min_weight_grad_sqnr, (
+        f"Weight grad sqnr {sqnr} is too low, must be >= {min_weight_grad_sqnr}"
+    )
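Note: the assertions above compare the MXFP8 path against the bfloat16 reference using compute_error, which reports a signal-to-quantization-noise ratio (SQNR) in decibels, so higher is better. A minimal stand-alone version of that metric (illustrative only, not necessarily torchao's exact implementation) looks like this:

    import torch

    def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
        # SQNR in dB: 20 * log10(||ref|| / ||ref - actual||),
        # i.e. 10 * log10(signal power / error power). The test requires
        # >= 27 dB for outputs, >= 26 dB for input grads, >= 24 dB for weight grads.
        ref, actual = ref.float(), actual.float()
        return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - actual))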

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 46 additions & 1 deletion
@@ -17,6 +17,8 @@
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
+    _to_mxfp8_per_group_colwise,
+    _to_mxfp8_per_group_rowwise,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
@@ -319,7 +321,50 @@ def forward(
 
     @staticmethod
     def backward(ctx, grad_out: torch.Tensor):
-        raise NotImplementedError
+        A, B_t, offs = ctx.saved_tensors
+        block_size = ctx.block_size
+        out_dtype = ctx.out_dtype
+        # Compute grad_A.
+        # grad_A = grad_output @ B
+        # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
+        grad_out_scale, grad_out_mx = to_mx(
+            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
+        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(
+            B_t.transpose(-2, -1).contiguous(),
+            block_size=block_size,
+            elem_dtype=torch.float8_e4m3fn,
+        )
+        grad_A = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_mx,
+            grad_out_scale,
+            B_t_mx,
+            B_t_scale,
+            offs=offs,
+            out_dtype=out_dtype,
+        )
+        # Compute grad_B = grad_output_t @ A
+        grad_out_t_mx, grad_out_t_scale = _to_mxfp8_per_group_rowwise(
+            grad_out.transpose(-2, -1).contiguous(),
+            offs=offs,
+            block_size=block_size,
+        )
+        A_mx, A_scale = _to_mxfp8_per_group_colwise(
+            A,
+            offs=offs,
+            block_size=block_size,
+        )
+        grad_B = emulated_mxfp8_scaled_grouped_mm(
+            grad_out_t_mx,
+            grad_out_t_scale,
+            A_mx,
+            A_scale,
+            offs=offs,
+        )
+        # In forward we receive pre-transposed weights B_t as input
+        grad_B_t = grad_B.transpose(-2, -1)
+
+        return grad_A, grad_B_t, None, None, None
 
 
 def _to_mxfp8_3d_expert_weights_dim1(
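For intuition, the backward above is the standard matmul gradient rules applied per expert group, with each operand dynamically quantized to MXFP8 before the grouped matmuls: grad_A = grad_out @ B and grad_B = grad_out_t @ A, computed group by group along the token dimension using offs. Below is a minimal plain-float reference (no quantization; the explicit Python loop and variable names are illustrative, not torchao's implementation) that checks these formulas against autograd:

    import torch

    # Illustrative shapes: M tokens routed to E experts, weights stored as (E, K, N).
    M, K, N, E = 64, 32, 16, 4
    x = torch.randn(M, K, requires_grad=True)        # plays the role of "A"
    w_t = torch.randn(E, K, N, requires_grad=True)   # plays the role of "B_t"
    offs = torch.tensor([16, 32, 48, 64])            # end offset of each token group

    # Reference forward: each token group is multiplied by its expert's weight.
    outs, start = [], 0
    for e in range(E):
        end = int(offs[e])
        outs.append(x[start:end] @ w_t[e])           # (m_e, K) @ (K, N) -> (m_e, N)
        start = end
    out = torch.cat(outs, dim=0)                     # (M, N)

    grad_out = torch.randn_like(out)
    out.backward(grad_out)

    # Manual gradients matching the comments in backward():
    #   grad_A[group] = grad_out[group] @ B[e],            where B[e] = w_t[e].T
    #   grad_B_t[e]   = (grad_out[group].T @ A[group]).T = A[group].T @ grad_out[group]
    with torch.no_grad():
        grad_x = torch.empty_like(x)
        grad_w_t = torch.empty_like(w_t)
        start = 0
        for e in range(E):
            end = int(offs[e])
            grad_x[start:end] = grad_out[start:end] @ w_t[e].transpose(-2, -1)
            grad_w_t[e] = x[start:end].t() @ grad_out[start:end]
            start = end

    torch.testing.assert_close(grad_x, x.grad)
    torch.testing.assert_close(grad_w_t, w_t.grad)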

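On the "dynamic quant" part of the commit title: both operands of each grouped matmul in backward() are quantized to MXFP8 at runtime (via to_mx and the per-group helpers), meaning the block scales are derived from the live tensor values rather than calibrated ahead of time. A rough conceptual sketch of MX-style blockwise dynamic quantization (block_size=32, power-of-two scales, float8_e4m3fn elements; this is an illustration of the idea, not torchao's to_mx algorithm):

    import torch

    FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

    def mx_quantize_dynamic(x: torch.Tensor, block_size: int = 32):
        # Each contiguous block of `block_size` values along the last dim shares
        # one power-of-two scale, chosen at runtime from the block's max magnitude.
        blocks = x.reshape(-1, block_size).float()
        amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2**-126)
        scale = torch.exp2(torch.ceil(torch.log2(amax / FP8_E4M3_MAX)))
        q = (blocks / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
        return q.reshape(x.shape), scale

    def mx_dequantize(q: torch.Tensor, scale: torch.Tensor, block_size: int = 32):
        # Undo the blockwise scaling to recover an approximation of the input.
        return (q.float().reshape(-1, block_size) * scale).reshape(q.shape)

    # Round-trip example: error stays small when values in a block have similar magnitude.
    x = torch.randn(4, 64, dtype=torch.bfloat16)
    q, s = mx_quantize_dynamic(x)
    x_dq = mx_dequantize(q, s).to(torch.bfloat16)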