[moe training] use smaller block sizes for per group scaling kernels to improve perf #2668
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2668
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
⏳ No Failures, 8 Pending, as of commit 351a2a9 with merge base 18edd01.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 03c179b to 639845e
Force-pushed from 639845e to 1366277
Force-pushed from 1366277 to bcb7403
Force-pushed from bcb7403 to 351a2a9 ("…to improve perf", stack-info: PR: #2668, branch: danielvegamyhre/stack/27)
@xmfan Somehow these changes break our test with torch.compile, but I don't see how that's possible, since they only change the Triton autotuner configs and kernel internals. Would you mind taking a quick look? Repro:
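A likely repro, assuming the standard pytest invocation for the test shown in the failure output below (the exact command is an assumption):

pytest test/prototype/moe_training/test_training.py -k test_moe_float8_training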
Error: test/prototype/moe_training/test_training.py F [100%]
========================================================= FAILURES =========================================================
_______________________________________ test_moe_float8_training[True-target_fqns0] ________________________________________
target_fqns = ['experts'], compile = True
    @pytest.mark.parametrize(
        "target_fqns",
        [
            ["experts"],
            ["does.not.exist"],
        ],
    )
    @pytest.mark.parametrize("compile", [False, True])
    def test_moe_float8_training(target_fqns: list[str], compile: bool):
        # Set token group alignment size to 16. This is required so that
        # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
        # has the contraction dim be divisible by 16. 16 byte alignment is required
        # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
        set_token_group_alignment_size_m(16)

        model_args = TransformerModelArgs(
            moe_enabled=True,
            num_experts=8,
            dim=256,
        )
        init_std = 0.02
        device = torch.device("cuda")

        # reference bf16 MoE
        ref_model = MoE(model_args).to(torch.bfloat16).cuda()
        torch.manual_seed(42)
        ref_model.init_weights(init_std, device)

        # target MoE for testing conversion
        model = copy.deepcopy(ref_model)

        # assert starting params are identical for both models
        for param1, param2 in zip(model.parameters(), ref_model.parameters()):
            assert torch.equal(param1, param2)

        # convert MoE to float8 training
        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
            for target_fqn in target_fqns:
                if target_fqn in cur_fqn:
                    return True
            return False

        # quantize test model
        config = MoETrainingConfig()
        quantize_(model, config=config, filter_fn=moe_module_filter_fn)

        # validate that only the experts were converted
        _validate_model_conversion(
            model,
            target_fqns=target_fqns,
        )

        if compile:
            # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
            model = torch.compile(model, fullgraph=False)
            ref_model = torch.compile(ref_model, fullgraph=False)

        # inputs
        batch, seq, dim = 8, 2048, 256
        ref_x = torch.randn(
            batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
        )
        x = ref_x.detach().clone().requires_grad_(True)

        # forward pass
        ref_out = ref_model(ref_x)
>       out = model(x)
              ^^^^^^^^
test/prototype/moe_training/test_training.py:98:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py:413: in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py:817: in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:979: in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:963: in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:1646: in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:1281: in codegen_and_compile
    _recursive_post_grad_passes(gm, is_inference=is_inference)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:532: in _recursive_post_grad_passes
    post_grad_passes(gm, is_inference)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/fx_passes/post_grad.py:277: in post_grad_passes
    ).apply_graph_pass(decompose_triton_kernel_wrapper_functional)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/fx/passes/graph_transform_observer.py:87: in apply_graph_pass
    return pass_fn(self.gm.graph)
           ^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/fx_passes/post_grad.py:1215: in decompose_triton_kernel_wrapper_functional
    graph_pass.apply(graph)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py:1978: in apply
    entry.apply(m, graph, node)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py:1122: in apply
    self.handler(match, *match.args, **match.kwargs)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/fx_passes/post_grad.py:1213: in _
    match.replace_by_example(decomp, flat_args, run_functional_passes=False)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Match(..., [], {'kernel_idx': 2, 'constant_args_idx': 8, 'grid': [(512, 8, 1), (512, 8, 1), (512, 8, 1), (512, 8, 1), ...': -448.0, 'fp8_dtype_max': 448.0, 'round_scales_to_power_of_2': True, 'EPS': 1e-12}, 'tensors_to_clone': ['out_ptr']})
replacement_fn = <function decompose_triton_kernel_wrapper_functional.<locals>._.<locals>.decomp at 0x7f156eff3740>
args = [2, 8, 512, 8, 1, 512, ...]
trace_fn = functools.partial(<function fwd_only at 0x7f16567c5d00>, run_functional_passes=False)
run_functional_passes = False
    def replace_by_example(
        self,
        replacement_fn: ReplaceFn,
        args: Sequence[Any],
        trace_fn: Optional[TraceFn] = None,
        run_functional_passes: bool = True,
    ) -> None:
        """Replace with a graph generated by tracing the replacement_fn.

        Args:
            run_functional_passes (bool). If we should run passes that
                assume functional IR (like DCE, remove_noop_ops), on the
                replacement graph.
        """
        from torch._inductor.virtualized import NullHandler, V

        context = (
            V.fake_mode
            if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
            else contextlib.nullcontext()
        )

        def should_propagate_eager_input_vals(nodes: list[torch.fx.Node]) -> bool:
            if len(nodes) != 1:
                return False
            node = nodes[0]
            if "eager_input_vals" not in node.meta:
                return False
            return node.target in OrderedSet(
                [
                    torch.ops.higher_order.triton_kernel_wrapper_functional,
                    torch.ops.higher_order.auto_functionalized,
                    torch.ops.higher_order.auto_functionalized_v2,
                ]
            )

        with context:
            if trace_fn is None:
                trace_fn = functools.partial(
                    fwd_only, run_functional_passes=run_functional_passes
                )

            if should_propagate_eager_input_vals(self.nodes):
                # Our strategy is:
                # 1) trace out the graph with eager_input_vals (which have accurate eager-mode metadata)
                # 2) trace out the graph with vals (which have the accurate Inductor metadata)
                # 3) Propagate the eager_input_vals from the first graph to the second.
                # 4) Use the second graph as the replacement graph.

                # Construct a map of node -> FakeTensor val in eager_input_vals
                node_to_val = {}
                fake_args, fake_kwargs = self.nodes[0].meta["eager_input_vals"]
                fake_kwargs = {**fake_kwargs}
                match_args, match_kwargs = tuple(self.args), self.kwargs

                def record(node: torch.fx.Node, val: Any) -> None:
                    if isinstance(node, torch.fx.Node):
                        node_to_val[node] = val

                torch.utils._pytree.tree_map(
                    record, (match_args, match_kwargs), (fake_args, fake_kwargs)
                )
                # map args to their FakeTensor val in eager_input_vals
                example_vals = torch.fx.map_arg(args, lambda arg: node_to_val[arg])

                # first graph
                graph_with_eager_vals = trace_fn(replacement_fn, example_vals)

                # second graph
                example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"])
                replacement = trace_fn(graph_with_eager_vals, example_vals)

                # propagate metadata from first graph to second
                # NB: This assertion might not be true in general, but it is true for
                # the two use cases we have
                # (triton_kernel_wrapper_functional, auto_functionalized)
>               assert len(graph_with_eager_vals.graph.nodes) == len(
                    replacement.graph.nodes
                )
E       torch._inductor.exc.InductorError: AssertionError:
E
E       Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py:311: InductorError
@zou3519 @oulgen I was told by Simon you might have some insight into what appears to be a bug in the Inductor pattern matcher (the `replace_by_example` assertion in the traceback above).
I have been looking for a repro for this, actually.
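For context, here is a minimal, hypothetical sketch of the general setup that goes down this code path: a user-defined, autotuned Triton kernel called from a torch.compile'd function, which is what gets wrapped in triton_kernel_wrapper_functional and then hits the decompose_triton_kernel_wrapper_functional pass shown in the traceback. This is not a confirmed repro of the assertion failure; the kernel and names below are made up for illustration.

import torch
import triton
import triton.language as tl

# Hypothetical kernel: scales a 1D tensor by a constant. The autotuner picks
# a BLOCK_SIZE, just like the per-group scaling kernels changed in this PR.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
    ],
    key=["n_elements"],
)
@triton.jit
def _scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)

def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    # grid size depends on the autotuned BLOCK_SIZE
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _scale_kernel[grid](x, out, n_elements, s)
    return out

@torch.compile
def f(x):
    return scale(x, 2.0)

if __name__ == "__main__":
    x = torch.randn(4096, device="cuda")
    torch.testing.assert_close(f(x), 2.0 * x)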
Stacked PRs:
[moe training] use smaller block sizes for per group scaling kernels to improve perf
In a trace of fp8 rowwise MoE training for Llama4, I noticed that the Triton kernel performing per-token-group scaling along dim0 was unexpectedly slow.
Profiling it with NCU showed a "grid too small" warning, meaning the kernel launch was not parallelized enough to utilize all of the GPU's SMs.
To fix this, I adjusted the kernel's autotuner configs to use much smaller block sizes, so each launch produces a much larger grid.
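As a rough illustration of the config change (the values, kernel, and grid layout below are hypothetical, not the exact ones in this PR): with a fixed problem size, smaller block sizes mean more programs in the launch grid, which is what resolves the "grid too small" warning.

import triton

# Hypothetical autotuner search space using smaller blocks; values are made up.
small_block_configs = [
    triton.Config({"BLOCK_SIZE": block_size}, num_warps=num_warps)
    for block_size in (128, 256, 512)
    for num_warps in (1, 2, 4)
]

# Grid-size arithmetic, assuming (for illustration) a 1D grid where each
# program handles BLOCK_SIZE of the 8 * 2048 = 16384 token rows in the test above.
n_rows = 8 * 2048
for block_size in (8192, 1024, 128):
    print(f"BLOCK_SIZE={block_size:5d} -> {triton.cdiv(n_rows, block_size):4d} program(s)")
# A grid of 2 programs leaves most of the ~132 SMs on an H100 idle; a grid of
# 128 programs does not.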
Kernel benchmarks now show a 5x-9x (!) speedup:
Old autotuner configs
New autotuner configs
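For reference, a hedged sketch of how kernel-level numbers like these are typically gathered with triton.testing.do_bench; the op below is a stand-in, not the actual torchao per-group scaling kernel:

import torch
import triton

# Stand-in for the per-token-group scaling op; shapes follow the test above.
def run_scaling_op(x: torch.Tensor) -> torch.Tensor:
    # placeholder: rowwise absmax, roughly what fp8 rowwise scaling needs
    return x.abs().amax(dim=-1, keepdim=True)

x = torch.randn(8 * 2048, 256, device="cuda", dtype=torch.bfloat16)

# do_bench runs warmup iterations and returns a runtime estimate in milliseconds.
ms = triton.testing.do_bench(lambda: run_scaling_op(x))
print(f"latency: {ms:.3f} ms")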