Make scaling type configurable for MoE training #2642


Merged (1 commit, Aug 5, 2025)
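
For context, a minimal usage sketch of the new knob, assembled from the tests below (the helper name and the "experts" FQN fragment are illustrative, not part of the PR):

import torch.nn as nn

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_


def convert_moe_to_mxfp8(model: nn.Module) -> nn.Module:
    # Only convert modules whose fully-qualified name mentions "experts"
    # (the fragment to match depends on your model's module names).
    def expert_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        return "experts" in cur_fqn

    # scaling_type defaults to FP8_ROWWISE, which preserves the old behavior.
    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
    quantize_(model, config=config, filter_fn=expert_filter_fn)
    return model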
test/prototype/moe_training/test_training.py (110 additions & 7 deletions)
@@ -12,7 +12,10 @@
 )

 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_

 from .testing_utils import _validate_model_conversion
@@ -72,7 +75,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False

     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)

     # validate that only the experts were converted
@@ -99,7 +102,105 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:

     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 29.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    min_input_grad_sqnr = 29.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    min_param_grad_sqnr = 25.0
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
+        )
+
+
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        ["does.not.exist"],
+    ],
+)
+def test_moe_mxfp8_training(target_fqns: list[str]):
+    block_size = 32
+
+    # Token groups must be divisible by 32 for mxfp8
+    set_token_group_alignment_size_m(block_size)
+
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+        multiple_of=block_size,
+        ffn_dim_multiplier=1.0,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(42)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = copy.deepcopy(ref_model)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to mxfp8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    min_out_sqnr = 25.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )

     # compute loss
     labels = torch.ones_like(ref_out)
@@ -112,13 +213,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:

     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    assert input_grad_sqnr.item() >= 30.0, (
-        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
+    min_input_grad_sqnr = 25.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )

     # validate param gradients
+    min_param_grad_sqnr = 21.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= 25.0, (
-            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
         )
torchao/prototype/moe_training/conversion_utils.py (14 additions & 3 deletions)
@@ -4,19 +4,24 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
+from enum import Enum
 from typing import Callable, Optional

 from torch import nn

 from torchao.core.config import AOBaseConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )

 logger: logging.Logger = logging.getLogger(__name__)


+class MoEScalingType(Enum):
+    FP8_ROWWISE = "fp8_rowwise"
+    MXFP8 = "mxfp8"
+
+
 class MoETrainingConfig(AOBaseConfig):
     """
     The MoETrainingConfig is specifically designed to be used on MoE models using
@@ -36,6 +41,10 @@ class MoETrainingConfig(AOBaseConfig):
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """

+    def __init__(self, scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE):
+        super().__init__()
+        self.scaling_type = scaling_type
+

 @register_quantize_module_handler(MoETrainingConfig)
 def _moe_training_transform(
@@ -76,6 +85,8 @@ def _swap_params(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+
     if isinstance(module, nn.Parameter) and (
         module_filter_fn is None or module_filter_fn(module, "")
     ):
@@ -84,7 +95,7 @@ def _swap_params(
             f"Does not support a root nn.Parameter with children: {module}"
         )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(module.data, config.scaling_type)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module

@@ -110,7 +121,7 @@ def post_order_traversal(
         for param_name, param in module.named_parameters(recurse=False):
             if not isinstance(param.data, ScaledGroupedMMTensor):
                 new_param = nn.Parameter(
-                    ScaledGroupedMMTensor(param.data),
+                    ScaledGroupedMMTensor(param.data, config.scaling_type),
                     requires_grad=param.requires_grad,
                 )
                 setattr(module, param_name, new_param)
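
Note: the function-local import of ScaledGroupedMMTensor in _swap_params (replacing the module-level import removed above) presumably avoids a circular import, since tensor.py now imports MoEScalingType from conversion_utils.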
torchao/prototype/moe_training/scaled_grouped_mm.py (23 additions & 8 deletions)
@@ -11,6 +11,7 @@

 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
@@ -30,6 +31,7 @@ def _scaled_grouped_mm(
     B_t: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE,
 ) -> torch.Tensor:
     """
     This function performs dynamic float8 quantization with row-wise scaling
@@ -43,14 +45,27 @@
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
-    logger.info("Using scaled_grouped_mm")
-    return _Float8GroupedMM.apply(
-        A,
-        B_t,
-        offs,
-        out_dtype,
-    )
+    # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
+    if scaling_type == MoEScalingType.FP8_ROWWISE:
+        logger.info("Using fp8 rowwise scaled_grouped_mm")
+        return _Float8GroupedMM.apply(
+            A,
+            B_t,
+            offs,
+            out_dtype,
+        )
+    elif scaling_type == MoEScalingType.MXFP8:
+        logger.info("Using mxfp8 scaled_grouped_mm")
+        block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
+        return _MXFP8GroupedMM.apply(
+            A,
+            B_t,
+            offs,
+            block_size,
+            out_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported scaling type {scaling_type}")


 class _Float8GroupedMM(torch.autograd.Function):
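
To make the branch above concrete: FP8_ROWWISE derives one scale per row (or column) of an operand, while MXFP8 derives one power-of-two scale per 32-element block along the reduction dim, which is why block_size is 32 here and why the test calls set_token_group_alignment_size_m(32). A simplified sketch of just the scale computation (function names and rounding are illustrative; the real kernels handle jagged groups via offs and store mxfp8 scales as e8m0 exponents):

import torch

E4M3_MAX = 448.0  # max magnitude representable in float8_e4m3fn


def rowwise_scale(a: torch.Tensor) -> torch.Tensor:
    # FP8_ROWWISE: one scale per row, from that row's absolute max.
    amax = a.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    return amax / E4M3_MAX


def mxfp8_scale(a: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # MXFP8: one power-of-two scale per block_size-element block along the
    # last dim; assumes that dim is a multiple of block_size (cf. the
    # token-group alignment requirement in the test).
    blocks = a.reshape(*a.shape[:-1], -1, block_size)
    amax = blocks.abs().amax(dim=-1).clamp(min=1e-12)
    # Round the scale up to a power of two so each block fits in fp8 range.
    return torch.exp2(torch.ceil(torch.log2(amax / E4M3_MAX)))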
torchao/prototype/moe_training/tensor.py (26 additions & 8 deletions)
@@ -16,6 +16,7 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy

 from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType

 logger: logging.Logger = logging.getLogger(__name__)

@@ -41,15 +42,17 @@ class ScaledGroupedMMTensor(torch.Tensor):
     differentiable _scaled_grouped_mm autograd function.
     """

+    scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE
     grouped_mm_func_name = "_grouped_mm"
     offs_arg_name = "offs"

     @staticmethod
     def __new__(
         cls,
         tensor: torch.Tensor,
+        scaling_type: MoEScalingType,
     ):
-        return torch.Tensor._make_wrapper_subclass(
+        self = torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
@@ -61,12 +64,16 @@ def __new__(
             pin_memory=tensor.is_pinned(),
             requires_grad=tensor.requires_grad,
         )
+        self.scaling_type = scaling_type
+        return self

     def __init__(
         self,
         tensor: torch.Tensor,
+        scaling_type: MoEScalingType,
     ):
         self._data = tensor
+        self.scaling_type = scaling_type

     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -80,12 +87,20 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             # used for shared experts. This is basically the grouped_mm
             # kernel handling a bmm.
             A, B = args[0], args[1]
+            assert not isinstance(A, ScaledGroupedMMTensor), (
+                "A should not be a ScaledGroupedMMTensor"
+            )
+            assert isinstance(B, ScaledGroupedMMTensor), (
+                "B should be a ScaledGroupedMMTensor"
+            )
+            scaling_type = B.scaling_type
             A_is_2d = A.dim() == 2
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             if A_is_2d and B_is_3d and has_offs:
                 return _scaled_grouped_mm(
                     *args,
+                    scaling_type=scaling_type,
                     **kwargs,
                 )

@@ -97,8 +112,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
+        scaling_type = args[0].scaling_type
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data)
+            return ScaledGroupedMMTensor(args[0]._data, scaling_type)

         # unwrap args/kwargs
         unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
@@ -116,22 +132,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x),
+            lambda x: ScaledGroupedMMTensor(x, scaling_type),
             out,
         )

     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data})"
+        return f"ScaledGroupedMMTensor(data={self._data}, scaling_type={self.scaling_type})"

     def __tensor_flatten__(self):
-        # Metadata is empty but needed to make the subclass traceable for torch.compile.
-        metadata = {}
+        metadata = {"scaling_type": self.scaling_type}
         return ["_data"], metadata

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
+            flatten_spec["scaling_type"],
         )

     # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
@@ -158,14 +174,16 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs

-        # For training step 1+, out=unshared param.
+        # For training step 1+, out=unsharded param.
         if out is not None:
             if isinstance(out, ScaledGroupedMMTensor):
                 out_data = out._data
+                out.scaling_type = self.scaling_type
             elif isinstance(out, DTensor) and isinstance(
                 out._local_tensor, ScaledGroupedMMTensor
             ):
                 out_data = out._local_tensor._data
+                out._local_tensor.scaling_type = self.scaling_type
             else:
                 raise RuntimeError(
                     f"expect out to be ScaledGroupedMMTensor or DTensor with local_tensor=ScaledGroupedMM, but got {type(out)}"
@@ -188,6 +206,6 @@ def fsdp_post_all_gather(
             return

         # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
-        output = ScaledGroupedMMTensor(data)
+        output = ScaledGroupedMMTensor(data, self.scaling_type)
         inner_tensors = (data,)
         return output, inner_tensors
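
A hypothetical smoke check of the plumbing above: ops that preserve the subclass (detach here) now carry scaling_type along instead of resetting it to the class default:

import torch

from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

w = ScaledGroupedMMTensor(torch.randn(8, 16), MoEScalingType.MXFP8)
w2 = w.detach()  # hits the detach special case in __torch_dispatch__
assert isinstance(w2, ScaledGroupedMMTensor)
assert w2.scaling_type == MoEScalingType.MXFP8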