[Draft] Integrate TE/JAX GroupedGEMM for MoE. #2319
base: main
Conversation
@@ -1,4 +1,5 @@
# Copyright 2023–2025 Google LLC
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
@kyle-meggs @ultrons - I have no idea what our copyright policies are, curious what your thoughts are
+1 on this copyright comment
What Docker version did you use to get DeepSeek-v3 model training working on GPU? I see OOMs even on 256 H200s at initial_state_partial when adapting the TPU configuration to GPUs and training from scratch. Also, what MFU are you getting with DSV3 through MaxText?
I also needed this PR to get cudnn_attention working with the latest JAX/TransformerEngine: #2303
We are using the 25-08-14 nightly JAX container from JAX-Toolbox. Our experiments use token dropping with tokens-per-expert=8.
@mingxu1067 Could you provide the train command you used to test this / run DeepSeekV3 on GPUs?
Thanks for the change! Please attach your tests in the PR description.
src/MaxText/layers/moe.py
Outdated
scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=True,
Curious, what does this 2x2x mean?
is_2x2x=True means the quantizer outputs two copies of the quantized data: one used in the forward pass and one used in the backward pass.
@mingxu1067 I think you should not specify is_2x2x here, so that the program can decide based on the GPU arch, i.e. only doing 2x on Hopper: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/quantize/quantizer.py#L969
Done!
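For reference, a minimal sketch of the resulting quantizer arguments with is_2x2x dropped, assuming the TE/JAX import path below; only the keyword arguments themselves come from the diff, and the surrounding dict is illustrative rather than the exact call in moe.py.

```python
import jax.numpy as jnp
# Import path assumed from the TransformerEngine link above; the actual
# factory/entry point used in moe.py may differ.
from transformer_engine.jax.quantize import ScalingMode

# Keyword arguments as in the diff, with is_2x2x omitted so TE/JAX can pick
# the per-architecture default (2x only on Hopper).
quantizer_kwargs = dict(
    scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
    fwd_dtype=jnp.float8_e4m3fn,  # FP8 E4M3 for the forward pass
    bwd_dtype=jnp.float8_e5m2,    # FP8 E5M2 for the backward pass
    # is_2x2x intentionally not set: the quantizer decides whether to emit
    # the second quantized copy based on the GPU architecture.
)
```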
out_specs=layer_w0w1_psepc,
w_fsdp_axis="fsdp", w_fsdp_dim=1)
else:
w0_kernel_axes = ("exp", None, "mlp")
I don't recall the reason we put None in the embedding dimension, but I think we are good to reuse it. Could you help add a comment: "TODO(ranran): check if None should be replaced by 'embed_no_exp' for performance."?
Done.
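For concreteness, the requested comment could be placed like this (the assignment is taken from the diff context; only the TODO line is new):

```python
# TODO(ranran): check if None should be replaced by "embed_no_exp" for performance.
w0_kernel_axes = ("exp", None, "mlp")
```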
@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def _pmax_no_backward(x, axis):
Wondering if you have tested the gradient? We have some examples in moe_test.py
I have not run the unit tests, but we tested end-to-end DeepSeek-V3 model training. The loss curve matches the native XLA implementation. Will attach it to the description shortly.
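For readers following the thread, a hedged sketch of how a custom-VJP helper of this shape could be wired up: only the decorator and signature come from the diff, the straight-through backward rule is an assumption about what "no backward" means here, and the actual body in the PR may differ.

```python
import functools
import jax

@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def _pmax_no_backward(x, axis):
  # Forward: max-reduce over the named mesh axis; must run under a
  # shard_map/pmap that binds `axis`.
  return jax.lax.pmax(x, axis)

def _pmax_no_backward_fwd(x, axis):
  # No residuals needed: the assumed backward rule does not depend on x.
  return _pmax_no_backward(x, axis), None

def _pmax_no_backward_bwd(axis, _residuals, g):
  # Assumed straight-through rule: pass the cotangent through unchanged,
  # skipping any collective in the backward pass.
  return (g,)

_pmax_no_backward.defvjp(_pmax_no_backward_fwd, _pmax_no_backward_bwd)
```

A gradient check in the spirit of moe_test.py could then compare jax.grad of a small loss through this helper against the expected straight-through value under a test mesh.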
@Skylion007 @RissyRan will arrange the command and attach it with logs shortly.
Test command and log are attached.
Doesn't ici_fsdp_parallelism=4 and ici_expert_parallelism=8 imply 32 GPUs per node with this configuration?! Assuming it's on Blackwell.
@Skylion007 Sorry for the typo. It should be |
Hmm, I hit
with this fixed on 25-08-14: maxtext on main
@Skylion007 This error is from XLA rewriting some GEMMs to FP8, but this PR does not go down that path for FFN1s and FFN2. Are you using sparse-matmul?
I was running the baseline on main though, so I don't know if that changes anything.
@Skylion007 Could you try 32 nodes with my setup? It seems you are using 128 nodes. Also, could you share your XLA_FLAGS?
What should I set them to? I have them unset. When I disable quantization, I get the following error:
2025-09-10 22:10:27.267281: F external/xla/xla/service/gpu/transforms/gemm_rewriter.cc:1384] Check failed: b_contracting_dims[0] == num_batch_dims || b_contracting_dims[0] == num_batch_dims + 1
I suspect it might be related to the XLA flags needed?
Description
[NVIDIA - GPU]
Integrate Transformer Engine/JAX's grouped GEMM for better performance on MoE-related training.
For standalone GEMM performance, TE/JAX grouped GEMM can be up to 1.2x faster than the alternative batched GEMM.
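To make the description concrete, here is a minimal pure-JAX reference for the semantics a grouped GEMM provides in an MoE FFN: tokens pre-sorted by expert are multiplied by per-expert weight matrices in a single call. The names below are illustrative only; the PR dispatches to TE/JAX's fused grouped-GEMM kernel rather than this Python loop.

```python
import jax.numpy as jnp

def grouped_gemm_reference(tokens, expert_weights, group_sizes):
  """Non-jitted reference for the per-expert matmuls a grouped GEMM fuses.

  tokens:         [num_tokens, hidden]        rows pre-sorted by expert id
  expert_weights: [num_experts, hidden, ffn]  one weight matrix per expert
  group_sizes:    [num_experts] Python ints   tokens routed to each expert
  """
  outputs, start = [], 0
  for e, size in enumerate(group_sizes):
    outputs.append(tokens[start:start + size] @ expert_weights[e])
    start += size
  return jnp.concatenate(outputs, axis=0)  # [num_tokens, ffn]
```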
Tests
Tested on DeepSeek-V3 671B model training with token dropping and experts-per-token=1 on H100s.
Native JAX/XLA:
This PR:
Command to run:
Checklist
Before submitting this PR, please make sure (put X in square brackets):