
Commit fd04d1a

add differentiable mxfp8 grouped gemm with dynamic quant (forward pass)
stack-info: PR: #2627, branch: danielvegamyhre/stack/23
1 parent 9834869 commit fd04d1a

2 files changed: +99 −1 lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 22 additions & 0 deletions
@@ -249,3 +249,25 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
     sqnr = compute_error(ref_out, out)
     min_sqnr = 27.0
     assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
+
+
+@pytest.mark.parametrize("M", (1024, 4096))
+@pytest.mark.parametrize("K", (1024, 4096))
+@pytest.mark.parametrize("N", (1024, 4096))
+@pytest.mark.parametrize("num_experts", (1, 8, 16))
+def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
+    from torchao.prototype.moe_training.scaled_grouped_mm import (
+        _MXFP8GroupedMM,
+    )
+
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
+    offs = generate_jagged_offs(num_experts, M)
+    x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
+    block_size = 32
+
+    out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
+    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
+    sqnr = compute_error(ref_out, out)
+    min_sqnr = 27.0
+    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
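
Note: compute_error in these tests is assumed to return a signal-to-quantization-noise ratio (SQNR) in decibels, which is what the 27 dB threshold measures. A minimal sketch of such a helper (the name _sqnr is illustrative and not part of this diff):

import torch

def _sqnr(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||ref|| / ||ref - actual||).
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - actual))

Higher is better; min_sqnr = 27.0 is the same closeness bar used by the emulated mxfp8 grouped GEMM test above.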

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 77 additions & 1 deletion
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
@@ -18,6 +18,7 @@
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
 )
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -268,6 +269,81 @@ def backward(ctx, grad_output: torch.Tensor):
         return grad_A, grad_B.transpose(-2, -1), None, None, None, None
 
 
+class _MXFP8GroupedMM(torch.autograd.Function):
+    """Differentiable implementation of grouped GEMM with dynamic mxfp8 quantization."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        A: torch.Tensor,
+        B_t: torch.Tensor,
+        offs: Optional[torch.Tensor] = None,
+        block_size: int = 32,
+        out_dtype: Optional[torch.dtype] = torch.bfloat16,
+        emulated: bool = True,
+    ) -> torch.Tensor:
+        # torchao _scaled_grouped_mm only supports A=2D and B=3D.
+        assert A.ndim == 2, "A must be 2D"
+        assert B_t.ndim == 3, "B must be 3D"
+        assert block_size == 32, "Only block_size=32 is supported"
+        assert emulated, "Only emulated mxfp8 grouped gemm is supported"
+
+        # Cast A to mxfp8 across dim -1.
+        # A_mx shape: (M, K)
+        # A_scale shape: (M, K//block_size)
+        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+
+        # Cast B_t per-expert to mxfp8 across dim1.
+        # B_t_mx shape: (E, K, N)
+        # B_t_scale shape: (E, K//block_size, N)
+        B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(B_t, block_size=block_size)
+
+        # Store what we need for backward.
+        ctx.save_for_backward(A, B_t, offs)
+        ctx.out_dtype = out_dtype
+
+        # Perform scaled grouped GEMM and return result.
+        # output = input @ weight.T
+        # output shape: (M, N)
+        out = emulated_mxfp8_scaled_grouped_mm(
+            A_mx,
+            A_scale,
+            B_t_mx,
+            B_t_scale,
+            offs=offs,
+            block_size=block_size,
+            out_dtype=out_dtype,
+        )
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        raise NotImplementedError
+
+
+def _to_mxfp8_3d_expert_weights_dim1(
+    w_t: torch.Tensor,  # (num_experts, K, N)
+    block_size: int = 32,
+    elem_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Convert a 3D tensor of shape (experts, K, N) to MXFP8 format along dim1.
+    Args:
+        w_t (torch.Tensor): Input tensor to be converted.
+        block_size (int): Block size for MXFP8 quantization.
+        elem_dtype (torch.dtype): Element dtype for MXFP8 quantization.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Scale tensor and converted tensor.
+            - scale shape: (experts, K // block_size, N)
+            - output shape: (experts, K, N)
+    """
+    # To cast w_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
+    w_scale, w_mx = to_mx(
+        w_t.transpose(-2, -1).contiguous(), elem_dtype=elem_dtype, block_size=block_size
+    )
+    w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
+    return w_t_scale, w_t_mx
+
+
 def emulated_mxfp8_scaled_grouped_mm(
     A_mx: torch.Tensor,
     A_scale: torch.Tensor,
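
For readers skimming the diff, here is an illustrative shape walk-through of the dim1 cast performed by _to_mxfp8_3d_expert_weights_dim1, under the same assumptions the code above relies on (to_mx returns (scale, data) and scales along the last dim); the concrete sizes and the CUDA device are placeholders:

import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx

E, K, N, block_size = 2, 64, 128, 32
w_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")

# Move the quantization axis (K) to the last dim so to_mx scales along it.
w = w_t.transpose(-2, -1).contiguous()  # (E, N, K)
scale, data = to_mx(w, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
# scale: (E, N, K // block_size), data: (E, N, K)

# Transpose back so the results line up with the original (E, K, N) layout.
w_t_scale, w_t_mx = scale.transpose(-2, -1), data.transpose(-2, -1)
assert w_t_scale.shape == (E, K // block_size, N)
assert w_t_mx.shape == (E, K, N)

This matches the shapes documented in the forward pass comments: the scales are blocked along K while the quantized weights keep the original per-expert layout.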
