
Commit b3962a0

add differentiable mxfp8 grouped gemm with dynamic quant (forward pass)
stack-info: PR: #2627, branch: danielvegamyhre/stack/23
1 parent fa77af5 commit b3962a0

2 files changed: +97 −1 lines changed


test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 22 additions & 0 deletions
@@ -247,3 +247,25 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
     sqnr = compute_error(ref_out, out)
     min_sqnr = 27.0
     assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
+
+
+@pytest.mark.parametrize("M", (1024, 4096))
+@pytest.mark.parametrize("K", (1024, 4096))
+@pytest.mark.parametrize("N", (1024, 4096))
+@pytest.mark.parametrize("num_experts", (1, 8, 16))
+def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
+    from torchao.prototype.moe_training.scaled_grouped_mm import (
+        _MXFP8GroupedMM,
+    )
+
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
+    offs = generate_jagged_offs(num_experts, M)
+    x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
+    block_size = 32
+
+    out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
+    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
+    sqnr = compute_error(ref_out, out)
+    min_sqnr = 27.0
+    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
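Note on the accuracy check: compute_error returns an SQNR in dB between the bf16 reference output of torch._grouped_mm and the mxfp8 output, and the test requires at least 27 dB. A minimal sketch of that kind of metric, assuming the usual 20·log10 signal-to-noise definition (the actual helper is imported from torchao's test utilities, not defined in this commit):

    import torch

    def sqnr_db(ref: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        # SQNR in dB: 20 * log10(||ref|| / ||ref - out||); higher means the
        # quantized result tracks the reference more closely.
        return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - out))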

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 75 additions & 1 deletion
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import logging
-from typing import Optional
+from typing import Optional, Tuple

 import torch

@@ -18,6 +18,7 @@
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
 )
+from torchao.prototype.mx_formats.mx_tensor import to_mx

 logger: logging.Logger = logging.getLogger(__name__)

@@ -268,6 +269,79 @@ def backward(ctx, grad_output: torch.Tensor):
         return grad_A, grad_B.transpose(-2, -1), None, None, None, None


+class _MXFP8GroupedMM(torch.autograd.Function):
+    """Differentiable implementation of grouped GEMM with dynamic mxfp8 quantization."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        A: torch.Tensor,
+        B_t: torch.Tensor,
+        offs: Optional[torch.Tensor] = None,
+        block_size: int = 32,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    ) -> torch.Tensor:
+        # torchao _scaled_grouped_mm only supports A=2D and B=3D.
+        assert A.ndim == 2, "A must be 2D"
+        assert B_t.ndim == 3, "B must be 3D"
+        assert block_size == 32, "Only block_size=32 is supported"
+
+        # Cast A to mxfp8 across dim -1.
+        # A_mx shape: (M, K)
+        # A_scale shape: (M, K//block_size)
+        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+
+        # Cast B_t per-expert to mxfp8 across dim1.
+        # B_t_mx shape: (E, K, N)
+        # B_t_scale shape: (E, K//block_size, N)
+        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(B_t, block_size=block_size)
+
+        # Store what we need for backward.
+        ctx.save_for_backward(A, B_t, offs)
+        ctx.out_dtype = out_dtype
+
+        # Perform scaled grouped GEMM and return result.
+        # output = input @ weight.T
+        # output shape: (M, N)
+        out = emulated_mxfp8_scaled_grouped_mm(
+            A_mx,
+            A_scale,
+            B_t_mx,
+            B_t_scale,
+            offs=offs,
+            block_size=block_size,
+            out_dtype=out_dtype,
+        )
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        raise NotImplementedError
+
+
+def _to_mxfp8_3d_expert_weights_dim1(
+    w_t: torch.Tensor,  # (num_experts, K, N)
+    block_size: int = 32,
+    elem_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Convert a 3D tensor of shape (experts, K, N) to MXFP8 format along dim1.
+    Args:
+        w_t (torch.Tensor): Input tensor to be converted.
+        block_size (int): Block size for MXFP8 quantization.
+        elem_dtype (torch.dtype): Element dtype for MXFP8 quantization.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Scale tensor and converted tensor.
+            - scale shape: (experts, K // block_size, N)
+            - output shape: (experts, K, N)
+    """
+    # To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
+    w_scale, w_mx = to_mx(
+        w_t.transpose(-2, -1).contiguous(), elem_dtype=elem_dtype, block_size=block_size
+    )
+    w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
+    return w_t_scale, w_t_mx
+
+
 def emulated_mxfp8_scaled_grouped_mm(
     A_mx: torch.Tensor,
     A_scale: torch.Tensor,
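For intuition on the dim1 cast: to_mx quantizes along the last dimension, so _to_mxfp8_3d_expert_weights_dim1 transposes each expert's (K, N) weight to (N, K), quantizes K in blocks of 32, then transposes the scale and data back. A rough shape-bookkeeping sketch mirroring the comments in the diff (the sizes and variable names here are illustrative, not part of the commit):

    import torch
    from torchao.prototype.mx_formats.mx_tensor import to_mx

    E, K, N, block_size = 2, 64, 32, 32
    w_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")

    # Quantize along the last dim of the transposed weights (E, N, K)...
    scale, data = to_mx(
        w_t.transpose(-2, -1).contiguous(),
        elem_dtype=torch.float8_e4m3fn,
        block_size=block_size,
    )
    # ...then undo the transpose so the results line up with the original (E, K, N) layout.
    scale_t, data_t = scale.transpose(-2, -1), data.transpose(-2, -1)
    assert data_t.shape == (E, K, N)                 # quantized weights, per the diff's comments
    assert scale_t.shape == (E, K // block_size, N)  # one e8m0 scale per 32-element block of K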
