24 changes: 21 additions & 3 deletions test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -219,9 +219,7 @@ def compute_reference_forward(
return output_ref


@pytest.mark.parametrize("M", (1024, 4096))
@pytest.mark.parametrize("K", (1024, 4096))
@pytest.mark.parametrize("N", (1024, 4096))
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
@@ -249,3 +247,23 @@ def test_emulate_mxfp8_grouped_gemm(M, K, N, num_experts):
sqnr = compute_error(ref_out, out)
min_sqnr = 27.0
assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"


@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@pytest.mark.parametrize("num_experts", (1, 8, 16))
def test_mxfp8_grouped_gemm_with_dq_fwd(M, K, N, num_experts):
from torchao.prototype.moe_training.scaled_grouped_mm import (
_MXFP8GroupedMM,
)

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M)
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
block_size = 32

out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
sqnr = compute_error(ref_out, out)
min_sqnr = 27.0
assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"
78 changes: 77 additions & 1 deletion torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional
from typing import Optional, Tuple

import torch

@@ -18,6 +18,7 @@
from torchao.prototype.moe_training.utils import (
_is_column_major,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

logger: logging.Logger = logging.getLogger(__name__)

@@ -268,6 +269,81 @@ def backward(ctx, grad_output: torch.Tensor):
return grad_A, grad_B.transpose(-2, -1), None, None, None, None


class _MXFP8GroupedMM(torch.autograd.Function):
"""Differentiable implementation of grouped GEMM with dynamic mxpf8 quantization."""

@staticmethod
def forward(
ctx,
A: torch.Tensor,
B_t: torch.Tensor,
offs: Optional[torch.Tensor] = None,
block_size: int = 32,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
emulated: bool = True,
) -> torch.Tensor:
Contributor: add an emulated flag and assert that it's True until we have a real kernel, to make the intent crystal clear?

Contributor Author: Done

# torchao _scaled_grouped_mm only supports A=2D and B=3D.
assert A.ndim == 2, "A must be 2D"
assert B_t.ndim == 3, "B must be 3D"
assert block_size == 32, "Only block_size=32 is supported"
assert emulated, "Only emulated mxfp8 grouped gemm is supported"

# Cast A to mxfp8 across dim -1.
# A_mx shape: (M, K)
# A_scale shape: (M, K//block_size)
A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# Cast B_t per-expert to mxfp8 across dim1.
# B_t_mx shape: (E, K, N)
# B_t_scale shape: (E, K//block_size, N)
B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(B_t, block_size=block_size)

# Store what we need for backward.
ctx.save_for_backward(A, B_t, offs)
ctx.out_dtype = out_dtype

# Perform scaled grouped GEMM and return result.
# output = input @ weight.T
# output shape: (M, N)
out = emulated_mxfp8_scaled_grouped_mm(
A_mx,
A_scale,
B_t_mx,
B_t_scale,
offs=offs,
block_size=block_size,
out_dtype=out_dtype,
)
return out

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
raise NotImplementedError


def _to_mxfp8_3d_expert_weights_dim1(
w_t: torch.Tensor, # (num_experts, K, N)
block_size: int = 32,
elem_dtype: torch.dtype = torch.float8_e4m3fn,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert a 3D tensor of shape (experts, K, N) to MXFP8 format along dim1.
Args:
w_t (torch.Tensor): Input tensor of expert weights to be converted, with shape (num_experts, K, N).
block_size (int): Block size for MXFP8 quantization.
elem_dtype (torch.dtype): Element dtype for MXFP8 quantization.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Converted tensor and scale tensor.
- scale shape: (experts, K // block_size, N)
- output shape: (experts, K, N)
"""
# To cast B_t per-expert to mxfp8 across dim1, we transpose the last two dims, cast along dim -1, then transpose back.
w_scale, w_mx = to_mx(
w_t.transpose(-2, -1).contiguous(), elem_dtype=elem_dtype, block_size=block_size
)
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
return w_t_scale, w_t_mx
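
For reference, a minimal shape sanity check (not part of the diff) for the two quantization steps used in forward; the example sizes and the CUDA device are illustrative assumptions, and to_mx is the same import used at the top of this file:

# Illustrative only: verify the shapes documented in the comments above.
M, E, K, N, block_size = 256, 4, 64, 128, 32
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
B_t = torch.randn(E, K, N, dtype=torch.bfloat16, device="cuda")

A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
assert A_mx.shape == (M, K)                    # data keeps the (M, K) layout
assert A_scale.shape == (M, K // block_size)   # one scale per 32-element block along dim -1

B_t_scale, B_t_mx = _to_mxfp8_3d_expert_weights_dim1(B_t, block_size=block_size)
assert B_t_mx.shape == (E, K, N)                   # data keeps the (experts, K, N) layout
assert B_t_scale.shape == (E, K // block_size, N)  # scale blocks run along dim1 (K)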


def emulated_mxfp8_scaled_grouped_mm(
A_mx: torch.Tensor,
A_scale: torch.Tensor,