From a221a9efa88fb7f86c55d98b4c275e7fc4f8992c Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 30 Jul 2025 16:19:19 -0700 Subject: [PATCH] Make scaling type configurable for MoE training stack-info: PR: https://github.com/pytorch/ao/pull/2642, branch: danielvegamyhre/stack/26 --- test/prototype/moe_training/test_training.py | 117 ++++++++++++++++-- .../moe_training/conversion_utils.py | 17 ++- .../moe_training/scaled_grouped_mm.py | 31 +++-- torchao/prototype/moe_training/tensor.py | 34 +++-- 4 files changed, 173 insertions(+), 26 deletions(-) diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py index 5a86b03804..d08f218842 100644 --- a/test/prototype/moe_training/test_training.py +++ b/test/prototype/moe_training/test_training.py @@ -12,7 +12,10 @@ ) from torchao.float8.float8_utils import compute_error -from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig +from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, +) from torchao.quantization.quant_api import quantize_ from .testing_utils import _validate_model_conversion @@ -72,7 +75,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MoETrainingConfig() + config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # validate that only the experts were converted @@ -99,7 +102,105 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate output out_sqnr = compute_error(out, ref_out) - assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}." + min_out_sqnr = 29.0 + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) + + # compute loss + labels = torch.ones_like(ref_out) + ref_loss = F.mse_loss(ref_out, labels) + out_loss = F.mse_loss(out, labels) + + # backward pass + ref_loss.backward() + out_loss.backward() + + # validate input gradient + input_grad_sqnr = compute_error(x.grad, ref_x.grad) + min_input_grad_sqnr = 29.0 + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." + ) + + # validate param gradients + min_param_grad_sqnr = 25.0 + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + param_grad_sqnr = compute_error(param1.grad, param2.grad) + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." 
+ ) + + +@pytest.mark.parametrize( + "target_fqns", + [ + ["experts"], + ["does.not.exist"], + ], +) +def test_moe_mxfp8_training(target_fqns: list[str]): + block_size = 32 + + # Token groups must be divisible by 32 for mxfp8 + set_token_group_alignment_size_m(block_size) + + model_args = TransformerModelArgs( + moe_enabled=True, + num_experts=8, + dim=256, + multiple_of=block_size, + ffn_dim_multiplier=1.0, + ) + init_std = 0.02 + device = torch.device("cuda") + + # reference bf16 MoE + ref_model = MoE(model_args).to(torch.bfloat16).cuda() + torch.manual_seed(42) + ref_model.init_weights(init_std, device) + + # target MoE for testing conversion + model = copy.deepcopy(ref_model) + + # assert starting params are identical for both models + for param1, param2 in zip(model.parameters(), ref_model.parameters()): + assert torch.equal(param1, param2) + + # convert MoE to float8 training + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in target_fqns: + if target_fqn in cur_fqn: + return True + return False + + # quantize test model + config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8) + quantize_(model, config=config, filter_fn=moe_module_filter_fn) + + # validate that only the experts were converted + _validate_model_conversion( + model, + target_fqns=target_fqns, + ) + + # inputs + batch, seq, dim = 8, 2048, 256 + ref_x = torch.randn( + batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device + ) + x = ref_x.detach().clone().requires_grad_(True) + + # forward pass + ref_out = ref_model(ref_x) + out = model(x) + + # validate output + out_sqnr = compute_error(out, ref_out) + min_out_sqnr = 25.0 + assert out_sqnr.item() >= min_out_sqnr, ( + f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}." + ) # compute loss labels = torch.ones_like(ref_out) @@ -112,13 +213,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # validate input gradient input_grad_sqnr = compute_error(x.grad, ref_x.grad) - assert input_grad_sqnr.item() >= 30.0, ( - f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}." + min_input_grad_sqnr = 25.0 + assert input_grad_sqnr.item() >= min_input_grad_sqnr, ( + f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}." ) # validate param gradients + min_param_grad_sqnr = 21.0 for param1, param2 in zip(model.parameters(), ref_model.parameters()): param_grad_sqnr = compute_error(param1.grad, param2.grad) - assert param_grad_sqnr.item() >= 25.0, ( - f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}." + assert param_grad_sqnr.item() >= min_param_grad_sqnr, ( + f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}." ) diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py index 2da8186f2d..c6492c9dbd 100644 --- a/torchao/prototype/moe_training/conversion_utils.py +++ b/torchao/prototype/moe_training/conversion_utils.py @@ -4,12 +4,12 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
import logging +from enum import Enum from typing import Callable, Optional from torch import nn from torchao.core.config import AOBaseConfig -from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -17,6 +17,11 @@ logger: logging.Logger = logging.getLogger(__name__) +class MoEScalingType(Enum): + FP8_ROWWISE = "fp8_rowwise" + MXFP8 = "mxfp8" + + class MoETrainingConfig(AOBaseConfig): """ The MoETrainingConfig is specifically designed to be used on MoE models using @@ -36,6 +41,10 @@ class MoETrainingConfig(AOBaseConfig): For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor. """ + def __init__(self, scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE): + super().__init__() + self.scaling_type = scaling_type + @register_quantize_module_handler(MoETrainingConfig) def _moe_training_transform( @@ -76,6 +85,8 @@ def _swap_params( Returns: nn.Module: The modified module with swapped linear layers. """ + from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor + if isinstance(module, nn.Parameter) and ( module_filter_fn is None or module_filter_fn(module, "") ): @@ -84,7 +95,7 @@ def _swap_params( f"Does not support a root nn.Parameter with children: {module}" ) if not isinstance(module.data, ScaledGroupedMMTensor): - new_data = ScaledGroupedMMTensor(module.data) + new_data = ScaledGroupedMMTensor(module.data, config.scaling_type) return nn.Parameter(new_data, requires_grad=module.requires_grad) return module @@ -110,7 +121,7 @@ def post_order_traversal( for param_name, param in module.named_parameters(recurse=False): if not isinstance(param.data, ScaledGroupedMMTensor): new_param = nn.Parameter( - ScaledGroupedMMTensor(param.data), + ScaledGroupedMMTensor(param.data, config.scaling_type), requires_grad=param.requires_grad, ) setattr(module, param_name, new_param) diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index fd22186939..c997c9cc9b 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -11,6 +11,7 @@ from torchao.float8.config import ScalingGranularity from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated +from torchao.prototype.moe_training.conversion_utils import MoEScalingType from torchao.prototype.moe_training.kernels import ( triton_fp8_col_major_jagged_colwise_scales, triton_fp8_row_major_jagged_rowwise_scales, @@ -30,6 +31,7 @@ def _scaled_grouped_mm( B_t: torch.Tensor, offs: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = torch.bfloat16, + scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE, ) -> torch.Tensor: """ This function performs dynamic float8 quantization with row-wise scaling @@ -43,14 +45,27 @@ def _scaled_grouped_mm( offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. """ - # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging. - logger.info("Using scaled_grouped_mm") - return _Float8GroupedMM.apply( - A, - B_t, - offs, - out_dtype, - ) + # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging. 
+ if scaling_type == MoEScalingType.FP8_ROWWISE: + logger.info("Using fp8 rowwise scaled_grouped_mm") + return _Float8GroupedMM.apply( + A, + B_t, + offs, + out_dtype, + ) + elif scaling_type == MoEScalingType.MXFP8: + logger.info("Using mxfp8 scaled_grouped_mm") + block_size = 32 # TODO: should we make this configurable? plumb it through in a config somehow? + return _MXFP8GroupedMM.apply( + A, + B_t, + offs, + block_size, + out_dtype, + ) + else: + raise ValueError(f"Unsupported scaling type {scaling_type}") class _Float8GroupedMM(torch.autograd.Function): diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 8b7ff4ae54..1ddd098675 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -16,6 +16,7 @@ from torch.distributed.fsdp import MixedPrecisionPolicy from torchao.prototype.moe_training import _scaled_grouped_mm +from torchao.prototype.moe_training.conversion_utils import MoEScalingType logger: logging.Logger = logging.getLogger(__name__) @@ -41,6 +42,7 @@ class ScaledGroupedMMTensor(torch.Tensor): differentiable _scaled_grouped_mm autograd function. """ + scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE grouped_mm_func_name = "_grouped_mm" offs_arg_name = "offs" @@ -48,8 +50,9 @@ class ScaledGroupedMMTensor(torch.Tensor): def __new__( cls, tensor: torch.Tensor, + scaling_type: MoEScalingType, ): - return torch.Tensor._make_wrapper_subclass( + self = torch.Tensor._make_wrapper_subclass( cls, tensor.size(), strides=tensor.stride(), @@ -61,12 +64,16 @@ def __new__( pin_memory=tensor.is_pinned(), requires_grad=tensor.requires_grad, ) + self.scaling_type = scaling_type + return self def __init__( self, tensor: torch.Tensor, + scaling_type: MoEScalingType, ): self._data = tensor + self.scaling_type = scaling_type @classmethod def __torch_function__(cls, func, types, args, kwargs={}): @@ -80,12 +87,20 @@ def __torch_function__(cls, func, types, args, kwargs={}): # used for shared experts. This is basically the grouped_mm # kernel handling a bmm. 
A, B = args[0], args[1] + assert not isinstance(A, ScaledGroupedMMTensor), ( + "A should not be a ScaledGroupedMMTensor" + ) + assert isinstance(B, ScaledGroupedMMTensor), ( + "B should be a ScaledGroupedMMTensor" + ) + scaling_type = B.scaling_type A_is_2d = A.dim() == 2 B_is_3d = B.dim() == 3 has_offs = kwargs.get(cls.offs_arg_name) is not None if A_is_2d and B_is_3d and has_offs: return _scaled_grouped_mm( *args, + scaling_type=scaling_type, **kwargs, ) @@ -97,8 +112,9 @@ def __torch_function__(cls, func, types, args, kwargs={}): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs={}): # detach is special case + scaling_type = args[0].scaling_type if func == torch.ops.aten.detach.default: - return ScaledGroupedMMTensor(args[0]._data) + return ScaledGroupedMMTensor(args[0]._data, scaling_type) # unwrap args/kwargs unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x @@ -116,22 +132,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}): # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass return pytree.tree_map_only( torch.Tensor, - lambda x: ScaledGroupedMMTensor(x), + lambda x: ScaledGroupedMMTensor(x, scaling_type), out, ) def __repr__(self): - return f"ScaledGroupedMMTensor(data={self._data})" + return f"ScaledGroupedMMTensor(data={self._data}, scaling_type={self.scaling_type})" def __tensor_flatten__(self): - # Metadata is empty but needed to make the subclass traceable for torch.compile. - metadata = {} + metadata = {"scaling_type": self.scaling_type} return ["_data"], metadata @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): return ScaledGroupedMMTensor( inner_tensors["_data"], + flatten_spec["scaling_type"], ) # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81 @@ -158,14 +174,16 @@ def fsdp_post_all_gather( ): (data,) = all_gather_outputs - # For training step 1+, out=unshared param. + # For training step 1+, out=unsharded param. if out is not None: if isinstance(out, ScaledGroupedMMTensor): out_data = out._data + out.scaling_type = self.scaling_type elif isinstance(out, DTensor) and isinstance( out._local_tensor, ScaledGroupedMMTensor ): out_data = out._local_tensor._data + out._local_tensor.scaling_type = self.scaling_type else: raise RuntimeError( f"expect out to be ScaledGroupedMMTensor or DTensor with local_tensor=ScaledGroupedMM, but got {type(out)}" @@ -188,6 +206,6 @@ def fsdp_post_all_gather( return # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor. - output = ScaledGroupedMMTensor(data) + output = ScaledGroupedMMTensor(data, self.scaling_type) inner_tensors = (data,) return output, inner_tensors
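
For reference, a minimal sketch of how the new scaling_type knob is exercised, mirroring the converted tests above. TinyExperts and TinyMoE below are hypothetical stand-ins for a real MoE layer (the tests use a full MoE model); only the MoETrainingConfig / quantize_ usage comes from this patch.

import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_


class TinyExperts(nn.Module):
    # Hypothetical stand-in for an MoE "experts" module holding a 3D
    # (num_experts, dim, dim) weight of the kind used by grouped GEMMs.
    def __init__(self, num_experts: int = 8, dim: int = 256):
        super().__init__()
        self.w = nn.Parameter(
            torch.randn(num_experts, dim, dim, dtype=torch.bfloat16)
        )


class TinyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = TinyExperts()


def expert_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Convert only parameters of modules under the "experts" FQN,
    # as in the tests above.
    return "experts" in cur_fqn


model = TinyMoE()

# FP8_ROWWISE is the default; MXFP8 is opt-in via the new scaling_type field.
config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
quantize_(model, config=config, filter_fn=expert_filter_fn)

# The expert weight is now backed by a ScaledGroupedMMTensor carrying the
# chosen scaling type; torch._grouped_mm calls on it are routed through the
# differentiable _scaled_grouped_mm with scaling_type=MXFP8.
print(model.experts.w)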