[moe training] use smaller block sizes for per group scaling kernels to improve perf #2668

Open · wants to merge 1 commit into base: main from danielvegamyhre/stack/27

Conversation

@danielvegamyhre (Contributor) commented Aug 2, 2025

Stacked PRs:


[moe training] use smaller block sizes for per group scaling kernels to improve perf

While looking at a trace of fp8 rowwise MoE training for llama4, I noticed that the Triton kernel doing per-token-group scaling along dim0 was unexpectedly slow.

I took a look with NCU and found a "grid too small" warning, meaning the kernel was not launching enough thread blocks to keep all SMs busy.

[Screenshot (2025-08-02): NCU report showing the "grid too small" warning]

To fix this, I adjusted the kernel's autotuner configs to use much smaller block sizes.
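For context, a minimal sketch of the kind of change involved is below. The block-size values, the `BLOCK_SIZE` parameter name, and the launch-grid shape are illustrative assumptions, not the actual tuning space of the kernel in this PR:

```python
import triton

# Illustrative autotuner configs only -- values and parameter names are assumed.
# Smaller BLOCK_SIZE values launch more programs, so the grid is large enough
# to occupy all SMs instead of leaving most of the GPU idle.
old_configs = [triton.Config({"BLOCK_SIZE": bs}, num_warps=4) for bs in (512, 1024, 2048)]
new_configs = [triton.Config({"BLOCK_SIZE": bs}, num_warps=4) for bs in (32, 64, 128, 256)]


def grid(meta):
    # For a 1D launch over M rows, the grid has ceil(M / BLOCK_SIZE) programs.
    # With M = 256 rows and BLOCK_SIZE = 1024 that is a single program, far
    # fewer than the ~132 SMs on an H100; with BLOCK_SIZE = 32 it is 8 programs
    # per group, and the autotuner can pick whichever config fills the machine.
    # M is assumed to be passed to the kernel as a keyword argument so it is
    # visible in `meta` alongside the chosen config's kwargs.
    return (triton.cdiv(meta["M"], meta["BLOCK_SIZE"]),)
```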

Kernel benchmarking now shows a 5x - 9x (!) speedup:

Old autotuner configs

input_shape      n_groups  high_precision_dtype      torch_time_us    triton_time_us
-------------  ----------  ----------------------  ---------------  ----------------
(256, 4096)             4  torch.bfloat16                  1378.24           457.888
(256, 4096)             8  torch.bfloat16                  2806.5            264.992
(256, 4096)            16  torch.bfloat16                  5485.47           171.072
(4096, 4096)            4  torch.bfloat16                  1479.42           546.656
(4096, 4096)            8  torch.bfloat16                  2787.07           539.952
(4096, 4096)           16  torch.bfloat16                  5436.42           728.352
(65536, 4096)           4  torch.bfloat16                  7814.5           7410.94
(65536, 4096)           8  torch.bfloat16                  8674.78          7995.66
(65536, 4096)          16  torch.bfloat16                 10420.9           9111.55

New autotuner configs

input_shape      n_groups  high_precision_dtype      torch_time_us    triton_time_us
-------------  ----------  ----------------------  ---------------  ----------------
(256, 4096)             4  torch.bfloat16                  1804.35            52.48
(256, 4096)             8  torch.bfloat16                  2934.3             53.568
(256, 4096)            16  torch.bfloat16                  5674.56            51.712
(4096, 4096)            4  torch.bfloat16                  1494.94            80.912
(4096, 4096)            8  torch.bfloat16                  2849.54            86.496
(4096, 4096)           16  torch.bfloat16                  5531.39           108.448
(65536, 4096)           4  torch.bfloat16                  7794.98          1049.63
(65536, 4096)           8  torch.bfloat16                  8661.28          1180.59
(65536, 4096)          16  torch.bfloat16                 10419             1603.23
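A microbenchmark along these lines can be timed with `triton.testing.do_bench`. The sketch below is not the repo's benchmark script; `torch_scale_fn` and `triton_scale_fn` are placeholders for the eager and Triton implementations being compared, and the equally sized group layout is an assumption:

```python
import torch
from triton.testing import do_bench


def bench_pair(torch_scale_fn, triton_scale_fn, input_shape, n_groups, dtype=torch.bfloat16):
    # Placeholder benchmark sketch: torch_scale_fn / triton_scale_fn stand in
    # for the eager and Triton per-token-group scaling implementations.
    x = torch.randn(*input_shape, dtype=dtype, device="cuda")

    # Assume equally sized token groups along dim0; real MoE group offsets
    # come from the router and are generally uneven.
    rows_per_group = input_shape[0] // n_groups
    offs = torch.arange(1, n_groups + 1, device="cuda", dtype=torch.int32) * rows_per_group

    torch_us = do_bench(lambda: torch_scale_fn(x, offs)) * 1e3   # do_bench returns ms
    triton_us = do_bench(lambda: triton_scale_fn(x, offs)) * 1e3
    return torch_us, triton_us
```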

pytorch-bot bot commented Aug 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2668

Note: Links to docs will display an error until the docs builds have been completed.


⏳ No Failures, 8 Pending

As of commit 351a2a9 with merge base 18edd01:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the "CLA Signed" label on Aug 2, 2025
@danielvegamyhre force-pushed the danielvegamyhre/stack/27 branch from 03c179b to 639845e on August 2, 2025 21:39
@danielvegamyhre added the "topic: not user facing" label on Aug 2, 2025
@danielvegamyhre force-pushed the danielvegamyhre/stack/27 branch from 639845e to 1366277 on August 4, 2025 14:39
@danielvegamyhre force-pushed the danielvegamyhre/stack/27 branch from 1366277 to bcb7403 on August 4, 2025 14:52

Commit: [moe training] use smaller block sizes for per group scaling kernels to improve perf
stack-info: PR: #2668, branch: danielvegamyhre/stack/27

@danielvegamyhre force-pushed the danielvegamyhre/stack/27 branch from bcb7403 to 351a2a9 on August 5, 2025 15:29
@danielvegamyhre (Contributor, Author) commented Aug 5, 2025

@xmfan somehow these changes break our test with torch.compile, but I don't see how that's possible, since they only change the Triton autotuner configs and kernel internals... would you mind taking a quick look?

Repro:

  • pytest test/prototype/moe_training/test_training.py -k test_moe_float8_training[True-target_fqns0]

Error:

test/prototype/moe_training/test_training.py F                                                                       [100%]

========================================================= FAILURES =========================================================
_______________________________________ test_moe_float8_training[True-target_fqns0] ________________________________________

target_fqns = ['experts'], compile = True

    @pytest.mark.parametrize(
        "target_fqns",
        [
            ["experts"],
            ["does.not.exist"],
        ],
    )
    @pytest.mark.parametrize("compile", [False, True])
    def test_moe_float8_training(target_fqns: list[str], compile: bool):
        # Set token group alignment size to 16. This is required so that
        # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
        # has the contraction dim be divisible by 16. 16 byte alignment is required
        # for the slowest moving dim (stride 1), so 16 bytes / 1 byte per element in fp8 = 16 elements.
        set_token_group_alignment_size_m(16)
        model_args = TransformerModelArgs(
            moe_enabled=True,
            num_experts=8,
            dim=256,
        )
        init_std = 0.02
        device = torch.device("cuda")
    
        # reference bf16 MoE
        ref_model = MoE(model_args).to(torch.bfloat16).cuda()
        torch.manual_seed(42)
        ref_model.init_weights(init_std, device)
    
        # target MoE for testing conversion
        model = copy.deepcopy(ref_model)
    
        # assert starting params are identical for both models
        for param1, param2 in zip(model.parameters(), ref_model.parameters()):
            assert torch.equal(param1, param2)
    
        # convert MoE to float8 training
        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
            for target_fqn in target_fqns:
                if target_fqn in cur_fqn:
                    return True
            return False
    
        # quantize test model
        config = MoETrainingConfig()
        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
    
        # validate that only the experts were converted
        _validate_model_conversion(
            model,
            target_fqns=target_fqns,
        )
    
        if compile:
            # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
            model = torch.compile(model, fullgraph=False)
            ref_model = torch.compile(ref_model, fullgraph=False)
    
        # inputs
        batch, seq, dim = 8, 2048, 256
        ref_x = torch.randn(
            batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
        )
        x = ref_x.detach().clone().requires_grad_(True)
    
        # forward pass
        ref_out = ref_model(ref_x)
>       out = model(x)
              ^^^^^^^^

test/prototype/moe_training/test_training.py:98: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py:413: in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py:817: in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:979: in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:963: in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:1646: in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:1281: in codegen_and_compile
    _recursive_post_grad_passes(gm, is_inference=is_inference)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/compile_fx.py:532: in _recursive_post_grad_passes
    post_grad_passes(gm, is_inference)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/fx_passes/post_grad.py:277: in post_grad_passes
    ).apply_graph_pass(decompose_triton_kernel_wrapper_functional)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/fx/passes/graph_transform_observer.py:87: in apply_graph_pass
    return pass_fn(self.gm.graph)
           ^^^^^^^^^^^^^^^^^^^^^^
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/fx_passes/post_grad.py:1215: in decompose_triton_kernel_wrapper_functional
    graph_pass.apply(graph)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py:1978: in apply
    entry.apply(m, graph, node)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py:1122: in apply
    self.handler(match, *match.args, **match.kwargs)
../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/fx_passes/post_grad.py:1213: in _
    match.replace_by_example(decomp, flat_args, run_functional_passes=False)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Match(..., [], {'kernel_idx': 2, 'constant_args_idx': 8, 'grid': [(512, 8, 1), (512, 8, 1), (512, 8, 1), (512, 8, 1), ...': -448.0, 'fp8_dtype_max': 448.0, 'round_scales_to_power_of_2': True, 'EPS': 1e-12}, 'tensors_to_clone': ['out_ptr']})
replacement_fn = <function decompose_triton_kernel_wrapper_functional.<locals>._.<locals>.decomp at 0x7f156eff3740>
args = [2, 8, 512, 8, 1, 512, ...]
trace_fn = functools.partial(<function fwd_only at 0x7f16567c5d00>, run_functional_passes=False)
run_functional_passes = False

    def replace_by_example(
        self,
        replacement_fn: ReplaceFn,
        args: Sequence[Any],
        trace_fn: Optional[TraceFn] = None,
        run_functional_passes: bool = True,
    ) -> None:
        """Replace with a graph generated by tracing the replacement_fn.
    
        Args:
            run_functional_passes (bool). If we should run passes that
                assume functional IR (like DCE, remove_noop_ops), on the
                replacement graph.
    
        """
        from torch._inductor.virtualized import NullHandler, V
    
        context = (
            V.fake_mode
            if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None))
            else contextlib.nullcontext()
        )
    
        def should_propagate_eager_input_vals(nodes: list[torch.fx.Node]) -> bool:
            if len(nodes) != 1:
                return False
            node = nodes[0]
            if "eager_input_vals" not in node.meta:
                return False
            return node.target in OrderedSet(
                [
                    torch.ops.higher_order.triton_kernel_wrapper_functional,
                    torch.ops.higher_order.auto_functionalized,
                    torch.ops.higher_order.auto_functionalized_v2,
                ]
            )
    
        with context:
            if trace_fn is None:
                trace_fn = functools.partial(
                    fwd_only, run_functional_passes=run_functional_passes
                )
    
            if should_propagate_eager_input_vals(self.nodes):
                # Our strategy is:
                # 1) trace out the graph with eager_input_vals (which have accurate eager-mode metadata)
                # 2) trace out the graph with vals (which have the accurate Inductor metadata)
                # 3) Propagate the eager_input_vals from the first graph to the second.
                # 4) Use the second graph as the replacement graph.
    
                # Construct a map of node -> FakeTensor val in eager_input_vals
                node_to_val = {}
    
                fake_args, fake_kwargs = self.nodes[0].meta["eager_input_vals"]
                fake_kwargs = {**fake_kwargs}
                match_args, match_kwargs = tuple(self.args), self.kwargs
    
                def record(node: torch.fx.Node, val: Any) -> None:
                    if isinstance(node, torch.fx.Node):
                        node_to_val[node] = val
    
                torch.utils._pytree.tree_map(
                    record, (match_args, match_kwargs), (fake_args, fake_kwargs)
                )
                # map args to their FakeTensor val in eager_input_vals
                example_vals = torch.fx.map_arg(args, lambda arg: node_to_val[arg])
    
                # first graph
                graph_with_eager_vals = trace_fn(replacement_fn, example_vals)
    
                # second graph
                example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"])
                replacement = trace_fn(graph_with_eager_vals, example_vals)
    
                # propagate metadata from first graph to second
                # NB: This assertion might not be true in general, but it is true for
                # the two use cases we have
                # (triton_kernel_wrapper_functional, auto_functionalized)
>               assert len(graph_with_eager_vals.graph.nodes) == len(
                    replacement.graph.nodes
                )
E               torch._inductor.exc.InductorError: AssertionError: 
E               
E               Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

../.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_inductor/pattern_matcher.py:311: InductorError

@danielvegamyhre (Contributor, Author) commented Aug 5, 2025

@zou3519 @oulgen I was told by Simon that you might have some insight into what appears to be a bug in the wrap_triton API; see #2668 (comment) for context. Could you please take a look?

@zou3519 (Contributor) commented Aug 5, 2025

I have actually been looking for a repro for this.
