
Conversation

@mingxu1067 commented Sep 9, 2025

Description

[NVIDIA - GPU]
Integrate Transformer Engine/JAX's grouped GEMM for better performance in MoE-related training.
For standalone GEMMs, TE/JAX grouped GEMM can be up to 1.2x faster than the alternative batched GEMM.
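For intuition, here is a minimal NumPy sketch of the difference (shapes and group sizes are illustrative; TE's grouped GEMM fuses the per-expert loop into a single kernel rather than looping in Python):

```python
import numpy as np

# Batched GEMM: every expert multiplies the same (padded) number of tokens,
# so the inputs stack into one regular (E, M, K) batch.
E, M, K, N = 4, 8, 16, 32
x = np.random.randn(E, M, K).astype(np.float32)
w = np.random.randn(E, K, N).astype(np.float32)
y_batched = np.einsum("emk,ekn->emn", x, w)  # (E, M, N)

# Grouped GEMM: each expert sees a different number of routed tokens, so the
# problem is a list of differently sized matmuls sharing one token buffer.
group_sizes = [3, 8, 1, 4]
tokens = np.random.randn(sum(group_sizes), K).astype(np.float32)
outs, start = [], 0
for e, size in enumerate(group_sizes):
    outs.append(tokens[start:start + size] @ w[e])
    start += size
y_grouped = np.concatenate(outs)  # (sum(group_sizes), N)
```

A grouped kernel avoids both the padding of the batched formulation and the launch overhead of E separate GEMMs, which is where the reported speedup comes from.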

Tests

Tested on DeepSeek V3 671B model training with token dropping + expert-per-token=1 on H100s.
Native JAX/XLA:

completed step: 0, seconds: 79.240, TFLOP/s/device: 7.401, Tokens/s/device: 51.691, total_weights: 1048576, loss: 12.269
completed step: 1, seconds: 3.895, TFLOP/s/device: 150.557, Tokens/s/device: 1051.521, total_weights: 1048576, loss: 12.269
completed step: 2, seconds: 3.240, TFLOP/s/device: 180.994, Tokens/s/device: 1264.130, total_weights: 1048576, loss: 12.083
completed step: 3, seconds: 3.268, TFLOP/s/device: 179.443, Tokens/s/device: 1253.257, total_weights: 1048576, loss: 11.585
completed step: 4, seconds: 5.791, TFLOP/s/device: 101.264, Tokens/s/device: 707.279, total_weights: 1048576, loss: 11.146
completed step: 5, seconds: 3.231, TFLOP/s/device: 181.498, Tokens/s/device: 1267.553, total_weights: 1048576, loss: 10.720
completed step: 6, seconds: 3.232, TFLOP/s/device: 181.442, Tokens/s/device: 1267.172, total_weights: 1048576, loss: 10.309
completed step: 7, seconds: 10.892, TFLOP/s/device: 53.840, Tokens/s/device: 376.052, total_weights: 1048576, loss: 9.899
completed step: 8, seconds: 3.214, TFLOP/s/device: 182.458, Tokens/s/device: 1274.493, total_weights: 1048576, loss: 9.523
completed step: 9, seconds: 3.228, TFLOP/s/device: 181.667, Tokens/s/device: 1268.889, total_weights: 1048576, loss: 9.192
completed step: 10, seconds: 3.242, TFLOP/s/device: 180.882, Tokens/s/device: 1263.568, total_weights: 1048576, loss: 8.916
completed step: 11, seconds: 3.226, TFLOP/s/device: 181.779, Tokens/s/device: 1269.811, total_weights: 1048576, loss: 8.694
completed step: 12, seconds: 3.218, TFLOP/s/device: 182.231, Tokens/s/device: 1272.807, total_weights: 1048576, loss: 8.513
completed step: 13, seconds: 3.235, TFLOP/s/device: 181.274, Tokens/s/device: 1266.265, total_weights: 1048576, loss: 8.375
completed step: 14, seconds: 3.237, TFLOP/s/device: 181.162, Tokens/s/device: 1265.440, total_weights: 1048576, loss: 8.270

This PR:

completed step: 0, seconds: 118.801, TFLOP/s/device: 4.936, Tokens/s/device: 34.478, total_weights: 1048576, loss: 12.271
completed step: 1, seconds: 3.626, TFLOP/s/device: 161.724, Tokens/s/device: 1129.609, total_weights: 1048576, loss: 12.269
completed step: 2, seconds: 3.110, TFLOP/s/device: 188.548, Tokens/s/device: 1316.969, total_weights: 1048576, loss: 12.141
completed step: 3, seconds: 3.126, TFLOP/s/device: 187.602, Tokens/s/device: 1310.366, total_weights: 1048576, loss: 11.649
completed step: 4, seconds: 5.819, TFLOP/s/device: 100.770, Tokens/s/device: 703.860, total_weights: 1048576, loss: 11.099
completed step: 5, seconds: 3.128, TFLOP/s/device: 187.503, Tokens/s/device: 1309.672, total_weights: 1048576, loss: 10.574
completed step: 6, seconds: 3.151, TFLOP/s/device: 186.121, Tokens/s/device: 1300.020, total_weights: 1048576, loss: 10.111
completed step: 7, seconds: 12.467, TFLOP/s/device: 47.037, Tokens/s/device: 328.542, total_weights: 1048576, loss: 9.663
completed step: 8, seconds: 3.095, TFLOP/s/device: 189.502, Tokens/s/device: 1323.635, total_weights: 1048576, loss: 9.270
completed step: 9, seconds: 3.098, TFLOP/s/device: 189.260, Tokens/s/device: 1321.942, total_weights: 1048576, loss: 8.937
completed step: 10, seconds: 3.103, TFLOP/s/device: 188.973, Tokens/s/device: 1319.936, total_weights: 1048576, loss: 8.675
completed step: 11, seconds: 3.123, TFLOP/s/device: 187.775, Tokens/s/device: 1311.572, total_weights: 1048576, loss: 8.466
completed step: 12, seconds: 3.124, TFLOP/s/device: 187.726, Tokens/s/device: 1311.231, total_weights: 1048576, loss: 8.299
completed step: 13, seconds: 3.102, TFLOP/s/device: 189.032, Tokens/s/device: 1320.348, total_weights: 1048576, loss: 8.176
completed step: 14, seconds: 3.108, TFLOP/s/device: 188.704, Tokens/s/device: 1318.061, total_weights: 1048576, loss: 8.085

Command to run:

python3 maxtext/MaxText/train.py maxtext/MaxText/configs/base.yml \
    run_name=logdir \
    use_iota_embed=true \
    scan_layers=true \
    steps=15 \
    per_device_batch_size=1 \
    model_name=deepseek3-671b \
    remat_policy=full \
    enable_checkpointing=false \
    logits_dot_in_fp32=false \
    base_output_directory=local_train \
    dataset_path=local \
    dataset_type=synthetic \
    tokenizer_path=maxtext/assets/tokenizer_llama3.tiktoken \
    tokenizer_type=tiktoken \
    n_routing_groups=-1 \
    topk_routing_group=-1 \
    enable_goodput_recording=false \
    monitor_goodput=false \
    max_target_length=4096 \
    attention=cudnn_flash_jax \
    megablox=false \
    hardware=gpu_hardware \
    jax_cache_dir="jax_cache" \
    quantization=fp8 \
    sparse_matmul=false \
    dcn_fsdp_parallelism=1 \
    ici_fsdp_parallelism=4 \
    dcn_data_parallelism=1 \
    ici_data_parallelism=1 \
    ici_tensor_parallelism=1 \
    dcn_expert_parallelism=32 \
    ici_expert_parallelism=2 \
    te_grouped_gemm=true

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@@ -1,4 +1,5 @@
# Copyright 2023–2025 Google LLC
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
Collaborator:
@kyle-meggs @ultrons - I have no idea what our copyright policies are, curious what your thoughts are

Collaborator:
+1 on this copyright comment

@Skylion007 commented Sep 9, 2025

What Docker version did you use to get DeepSeek-v3 model training working on GPU? I see OOMs even on 256 H200s when adapting the TPU configuration to GPUs on the initial_state_partial when training from scratch. Also, what MFU are you getting with DSV3 through MaxText?

@Skylion007:
I also needed this PR to get cudnn_attention working with the latest JAX/TransformerEngine: #2303

@mingxu1067 (Author):
We are using the 25-08-14 nightly JAX container from JAX-Toolbox. Our experiments are run with token dropping + token-per-expert=8.

@Skylion007:
@mingxu1067 Could you provide the train command you used to test this / run DeepSeekV3 on GPUs?

@RissyRan (Collaborator) left a comment:
Thanks for the change! Please attach your tests in the PR description.


scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=True,
Collaborator:
Curious, what does this 2x2x mean?

Contributor:
is_2x2x=True means the quantizer outputs two copies of the quantized data: one used in the forward pass and one used in the backward pass.

@mingxu1067 I think you should not specify is_2x2x here, so that the program can decide based on the GPU arch, i.e. only doing 2x on Hopper: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/quantize/quantizer.py#L969

Author:
Done!
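As a rough illustration of the "2x" behavior discussed above, outside of TE (this uses the ml_dtypes float8 types directly, not the TE quantizer API, and omits scale factors):

```python
import numpy as np
import ml_dtypes  # provides the float8 dtypes used by JAX and TE

x = np.array([-1.5, 0.25, 96.0], dtype=np.float32)

# "2x" output: the same tensor is materialized in both float8 dtypes, one
# copy kept for the forward pass and one for the backward pass.
fwd_copy = x.astype(ml_dtypes.float8_e4m3fn)  # more mantissa bits: activations/weights
bwd_copy = x.astype(ml_dtypes.float8_e5m2)    # more exponent range: gradients
```

The split matches the fwd_dtype/bwd_dtype pair in the quoted diff: e4m3 favors precision for the forward values, e5m2 favors dynamic range for gradients.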

out_specs=layer_w0w1_psepc,
w_fsdp_axis="fsdp", w_fsdp_dim=1)
else:
w0_kernel_axes = ("exp", None, "mlp")
Collaborator:
I don't recall the reason we put "None" in the embedding dimension, but I think we are good to reuse it. Could you help add a comment: "TODO(ranran): check if None should be replaced by "embed_no_exp" for performance."

Author:
Done.



@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def _pmax_no_backward(x, axis):
Collaborator:
Wondering if you have tested the gradient? We have some examples in moe_test.py.

Author:
I have not added unit tests, but we tested end-to-end DeepSeek V3 model training; the loss curve matches the native XLA implementation. I will attach it to the description shortly.

@mingxu1067 (Author):
@Skylion007 @RissyRan I will arrange the command and attach it with logs shortly.

@mingxu1067 (Author):
The test command and log are attached.

@Skylion007 commented Sep 10, 2025:
Doesn't ici_fsdp_parallelism=4, and ici_expert_parallelism=8 imply 32 GPUs per node with this configuration?! Assuming it's on Blackwell

@mingxu1067 (Author):
@Skylion007 Sorry for the typo. It should be ICI_FSDP=4 x ICI_EP=2 x DCN_EP=32. Thank you for pointing that out.
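Spelled out, the corrected layout works out as follows (8 GPUs per node is an assumption matching a standard H100 system):

```python
# Mesh-size arithmetic for the corrected parallelism layout.
ici_fsdp, ici_ep = 4, 2        # intra-node (ICI) mesh axes
dcn_ep = 32                    # inter-node (DCN) mesh axis

gpus_per_node = ici_fsdp * ici_ep   # 4 * 2 = 8, one H100 node
nodes = dcn_ep                      # 32 nodes
total_gpus = gpus_per_node * nodes  # 256 GPUs total
```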

@Skylion007 commented Sep 10, 2025:
> @Skylion007 Sorry for the typo. It should be ICI_FSDP=4 x ICI_EP=2 x DCN_EP=32. Thank you for pointing that out.

Hmm, I hit

2025-09-10 20:58:14.694903: F external/xla/xla/service/gpu/transforms/gemm_rewriter.cc:1384] Check failed: b_contracting_dims[0] == num_batch_dims || b_contracting_dims[0] == num_batch_dims + 1
Fatal Python error: Aborted

with this fix on the 25-08-14 container, maxtext on main

@mingxu1067 (Author):
@Skylion007 This error comes from XLA rewriting some GEMMs to FP8, but this PR does not go down that path for FFN1 and FFN2. Are you using sparse-matmul?

@Skylion007 commented Sep 10, 2025:
> @Skylion007 This error comes from XLA rewriting some GEMMs to FP8, but this PR does not go down that path for FFN1 and FFN2. Are you using sparse-matmul?

 (core dumped) PYTHONPATH=src/ python3 -m MaxText.train MaxText/configs/base.yml hardware=gpu_multiprocess base_output_directory="/mnt/sharefs/users/runner/git/maxtext/maxtext-outputs/20250910-215610" run_name=gpu_train_test ici_expert_parallelism=2 ici_tensor_parallelism=1 ici_fsdp_parallelism=4 ici_data_parallelism=1 dcn_expert_parallelism=32 dcn_tensor_parallelism=1 dcn_fsdp_parallelism=2 dcn_data_parallelism=2 use_iota_embed=true scan_layers=true steps=15 per_device_batch_size=1 model_name=deepseek3-671b remat_policy=full enable_checkpointing=false logits_dot_in_fp32=false base_output_directory=local_train dataset_path=local dataset_type=synthetic tokenizer_path=maxtext/assets/tokenizer_llama3.tiktoken tokenizer_type=tiktoken n_routing_groups=-1 topk_routing_group=-1 enable_goodput_recording=false monitor_goodput=false max_target_length=4096 attention=cudnn_flash_jax megablox=false jax_cache_dir="jax_cache" quantization=fp8 sparse_matmul=false

I was running the baseline on main, though; I don't know if that changes anything.

@mingxu1067 (Author):
@Skylion007 Could you try 32 nodes with my setup? It seems you are using 128 nodes. Also, could you share your XLA_FLAGS?

@Skylion007:
> @Skylion007 Could you try 32 nodes with my setup? It seems you are using 128 nodes. Also, could you share your XLA_FLAGS?

What should I set them to? I have them unset.

When I disable quantization, I get the following error:

  File "/opt/jax/jax/_src/compiler.py", line 378, in backend_compile_and_load
IndexError: absl::container_internal::raw_hash_map<>::at

@Skylion007:
> @Skylion007 Could you try 32 nodes with my setup? It seems you are using 128 nodes. Also, could you share your XLA_FLAGS?

2025-09-10 22:10:27.267281: F external/xla/xla/service/gpu/transforms/gemm_rewriter.cc:1384] Check failed: b_contracting_dims[0] == num_batch_dims || b_contracting_dims[0] == num_batch_dims + 1
Fatal Python error: Aborted

sr/bin/bash: line 71: 2200155 Aborted                 (core dumped) PYTHONPATH=src/ python3 -m MaxText.train MaxText/configs/base.yml hardware=gpu_multiprocess base_output_directory="/mnt/sharefs/users/runner/git/maxtext/maxtext-outputs/20250910-220910" run_name=gpu_train_test ici_expert_parallelism=2 ici_tensor_parallelism=1 ici_fsdp_parallelism=4 ici_data_parallelism=1 dcn_expert_parallelism=32 dcn_tensor_parallelism=1 dcn_fsdp_parallelism=1 dcn_data_parallelism=1 use_iota_embed=true scan_layers=true steps=15 per_device_batch_size=1 model_name=deepseek3-671b remat_policy=full enable_checkpointing=false logits_dot_in_fp32=false base_output_directory=local_train dataset_path=local dataset_type=synthetic tokenizer_path=maxtext/assets/tokenizer_llama3.tiktoken tokenizer_type=tiktoken n_routing_groups=-1 topk_routing_group=-1 enable_goodput_recording=false monitor_goodput=false max_target_length=4096 attention=cudnn_flash_jax megablox=false jax_cache_dir="jax_cache" quantization=fp8 sparse_matmul=false

I suspect this might be related to the XLA flags needed?

@mingxu1067 changed the title from "Integrate TE/JAX GroupedGEMM for MoE." to "[Draft] Integrate TE/JAX GroupedGEMM for MoE." on Sep 19, 2025