
Commit bb05933

Make scaling type configurable for MoE training
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
1 parent 1f0d2bb commit bb05933

File tree: 4 files changed, +166 −20 lines

test/prototype/moe_training/test_training.py

Lines changed: 103 additions & 2 deletions
@@ -12,13 +12,19 @@
 )

 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_

 from .testing_utils import _validate_model_conversion

 # this test requires torchtitan
 try:
+    from torchtitan.experiments.llama4.infra.expert_parallel import (
+        set_token_group_alignment_size_m,
+    )
     from torchtitan.experiments.llama4.model.args import TransformerModelArgs
     from torchtitan.experiments.llama4.model.moe import MoE
 except ImportError:
@@ -63,7 +69,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False

     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)

     # validate that only the experts were converted
@@ -108,3 +114,98 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     assert param_grad_sqnr.item() >= 25.0, (
         f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
     )
+
+
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        ["does.not.exist"],
+    ],
+)
+def test_moe_mxfp8_training(target_fqns: list[str]):
+    block_size = 32
+
+    # Token groups must be divisible by 32 for mxfp8
+    set_token_group_alignment_size_m(block_size)
+
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+        multiple_of=block_size,
+        ffn_dim_multiplier=1.0,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(42)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = copy.deepcopy(ref_model)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to float8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    min_out_sqnr = 25.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    min_input_grad_sqnr = 25.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    min_param_grad_sqnr = 22.0
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
+        )
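With torchtitan installed on a CUDA machine (otherwise the test is skipped, since the torchtitan imports above are wrapped in a try/except), the new mxfp8 case can be run on its own; this is just standard pytest selection, nothing specific to this commit:

    pytest test/prototype/moe_training/test_training.py -k test_moe_mxfp8_training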

torchao/prototype/moe_training/conversion_utils.py

Lines changed: 14 additions & 3 deletions
@@ -4,19 +4,24 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
+from enum import Enum
 from typing import Callable, Optional

 from torch import nn

 from torchao.core.config import AOBaseConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )

 logger: logging.Logger = logging.getLogger(__name__)


+class MoEScalingType(Enum):
+    FP8_ROWWISE = "fp8_rowwise"
+    MXFP8 = "mxfp8"
+
+
 class MoETrainingConfig(AOBaseConfig):
     """
     The MoETrainingConfig is specifically designed to be used on MoE models using
@@ -36,6 +41,10 @@ class MoETrainingConfig(AOBaseConfig):
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """

+    def __init__(self, scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE):
+        super().__init__()
+        self.scaling_type = scaling_type
+

 @register_quantize_module_handler(MoETrainingConfig)
 def _moe_training_transform(
@@ -76,6 +85,8 @@ def _swap_params(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+
     if isinstance(module, nn.Parameter) and (
         module_filter_fn is None or module_filter_fn(module, "")
     ):
@@ -84,7 +95,7 @@
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(module.data, config.scaling_type)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module
@@ -110,7 +121,7 @@ def post_order_traversal(
         for param_name, param in module.named_parameters(recurse=False):
             if not isinstance(param.data, ScaledGroupedMMTensor):
                 new_param = nn.Parameter(
-                    ScaledGroupedMMTensor(param.data),
+                    ScaledGroupedMMTensor(param.data, config.scaling_type),
                     requires_grad=param.requires_grad,
                 )
                 setattr(module, param_name, new_param)
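Taken together with the test above, the new user-facing surface is just the scaling_type argument on MoETrainingConfig. A minimal usage sketch (the model variable and filter below are placeholders, not part of this commit):

    from torchao.prototype.moe_training.conversion_utils import (
        MoEScalingType,
        MoETrainingConfig,
    )
    from torchao.quantization.quant_api import quantize_

    # Default remains FP8_ROWWISE; pass MXFP8 to opt into mxfp8 grouped GEMMs.
    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
    # 'model' is a placeholder MoE module; only submodules matching the filter are converted.
    quantize_(model, config=config, filter_fn=lambda mod, fqn: "experts" in fqn)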

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 23 additions & 8 deletions
@@ -11,6 +11,7 @@

 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
@@ -30,6 +31,7 @@ def _scaled_grouped_mm(
     B_t: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE,
 ) -> torch.Tensor:
     """
     This function performs dynamic float8 quantization with row-wise scaling
@@ -43,14 +45,27 @@
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
-    logger.info("Using scaled_grouped_mm")
-    return _Float8GroupedMM.apply(
-        A,
-        B_t,
-        offs,
-        out_dtype,
-    )
+    # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
+    if scaling_type == MoEScalingType.FP8_ROWWISE:
+        logger.info("Using fp8 rowwise scaled_grouped_mm")
+        return _Float8GroupedMM.apply(
+            A,
+            B_t,
+            offs,
+            out_dtype,
+        )
+    elif scaling_type == MoEScalingType.MXFP8:
+        logger.info("Using mxfp8 scaled_grouped_mm")
+        block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
+        return _MXFP8GroupedMM.apply(
+            A,
+            B_t,
+            offs,
+            block_size,
+            out_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported scaling type {scaling_type}")


 class _Float8GroupedMM(torch.autograd.Function):
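For orientation, a direct call into the new branch would look roughly like the sketch below. The shapes and the offs convention (cumulative group boundaries along dim0 of A, with groups padded to multiples of 32 for mxfp8, as in the test) are assumptions for illustration; the commit itself only adds the scaling_type keyword and the dispatch above:

    import torch
    from torchao.prototype.moe_training.conversion_utils import MoEScalingType
    from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

    # Illustrative only: 2 token groups of 32 tokens each, routed to 2 experts.
    A = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda")
    B = torch.randn(2, 512, 256, dtype=torch.bfloat16, device="cuda")
    B_t = B.transpose(-2, -1)  # pass transposed expert weights, as the argument name suggests
    offs = torch.tensor([32, 64], dtype=torch.int32, device="cuda")

    out = _scaled_grouped_mm(A, B_t, offs=offs, scaling_type=MoEScalingType.MXFP8)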

torchao/prototype/moe_training/tensor.py

Lines changed: 26 additions & 7 deletions
@@ -16,6 +16,7 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy

 from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType

 logger: logging.Logger = logging.getLogger(__name__)

@@ -40,15 +41,17 @@ class ScaledGroupedMMTensor(torch.Tensor):
     differentiable _scaled_grouped_mm autograd function.
     """

+    scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE
     grouped_mm_func_name = "_grouped_mm"
     offs_arg_name = "offs"

     @staticmethod
     def __new__(
         cls,
         tensor: torch.Tensor,
+        scaling_type: MoEScalingType,
     ):
-        return torch.Tensor._make_wrapper_subclass(
+        self = torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
@@ -60,12 +63,16 @@ def __new__(
             pin_memory=tensor.is_pinned(),
             requires_grad=tensor.requires_grad,
         )
+        self.scaling_type = scaling_type
+        return self

     def __init__(
         self,
         tensor: torch.Tensor,
+        scaling_type: MoEScalingType,
     ):
         self._data = tensor
+        self.scaling_type = scaling_type

     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -79,12 +86,20 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             # used for shared experts. This is basically the grouped_mm
             # kernel handling a bmm.
             A, B = args[0], args[1]
+            assert not isinstance(A, ScaledGroupedMMTensor), (
+                "A should not be a ScaledGroupedMMTensor"
+            )
+            assert isinstance(B, ScaledGroupedMMTensor), (
+                "B should be a ScaledGroupedMMTensor"
+            )
+            scaling_type = B.scaling_type
             A_is_2d = A.dim() == 2
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             if A_is_2d and B_is_3d and has_offs:
                 return _scaled_grouped_mm(
                     *args,
+                    scaling_type=scaling_type,
                     **kwargs,
                 )

@@ -96,8 +111,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
+        scaling_type = args[0].scaling_type
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data)
+            return ScaledGroupedMMTensor(args[0]._data, scaling_type)

         # unwrap args/kwargs
         unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
@@ -115,20 +131,21 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x),
+            lambda x: ScaledGroupedMMTensor(x, scaling_type),
             out,
         )

     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data})"
+        return f"ScaledGroupedMMTensor(data={self._data}, scaling_type={self.scaling_type})"

     def __tensor_flatten__(self):
-        return ["_data"]
+        return ["_data"], {"scaling_type": self.scaling_type}

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
+            flatten_spec["scaling_type"],
         )

     # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
@@ -155,14 +172,16 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs

-        # For training step 1+, out=unshared param.
+        # For training step 1+, out=unsharded param.
         if out is not None:
             if isinstance(out, ScaledGroupedMMTensor):
                 out_data = out._data
+                out.scaling_type = self.scaling_type
             elif isinstance(out, DTensor) and isinstance(
                 out._local_tensor, ScaledGroupedMMTensor
             ):
                 out_data = out._local_tensor._data
+                out._local_tensor.scaling_type = self.scaling_type
             else:
                 raise RuntimeError(
                     f"expect out to be ScaledGroupedMMTensor or DTensor with local_tensor=ScaledGroupedMM, but got {type(out)}"
@@ -185,6 +204,6 @@ def fsdp_post_all_gather(
             return

         # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
-        output = ScaledGroupedMMTensor(data)
+        output = ScaledGroupedMMTensor(data, self.scaling_type)
         inner_tensors = (data,)
         return output, inner_tensors
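Since scaling_type now rides along through __new__/__init__, __tensor_flatten__/__tensor_unflatten__, and the detach special case, a quick sanity check of the metadata plumbing (a minimal sketch, using only the constructor and ops shown in this diff) would be:

    import torch
    from torchao.prototype.moe_training.conversion_utils import MoEScalingType
    from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

    w = torch.randn(8, 256, 512, dtype=torch.bfloat16)
    wrapped = ScaledGroupedMMTensor(w, MoEScalingType.MXFP8)

    # detach hits the special case in __torch_dispatch__ and should preserve the enum
    assert wrapped.detach().scaling_type == MoEScalingType.MXFP8
    # repr now includes the scaling type as well
    print(repr(wrapped))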
