Skip to content

[fix][nvbugs/5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce and add FP16 support. #6237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

timlee0212
Copy link
Collaborator

@timlee0212 timlee0212 commented Jul 21, 2025

Summary by CodeRabbit

Summary by CodeRabbit

  • New Features

    • Added support for half-precision (float16) data type in distributed allreduce operations.
    • Extended tests and functionality to handle multiple sequence lengths simultaneously.
  • Bug Fixes

    • Improved buffer clearing and flag management in distributed allreduce operations to prevent stale data from previous runs.
    • Enhanced synchronization and buffer management for more reliable distributed computation.
  • Chores

    • Updated internal buffer state tracking to ensure accurate handling of varying token counts during distributed operations.
    • Expanded supported data types for distributed allreduce to include float16.

Description

tl;dr:

  • Fix accuracy issue caused by lamport buffer clearing strategy.
  • Fix fallback path issue when MNNVL is not applicable.
  • Add FP16 support for MNNVL two-shot kernel.
  • Enhance the coverage of the unittest for MNNVL twoshot kernel.

The MNNVL twoshot kernel employs three buffers for Lamport synchronization in a circular manner. For each kernel call, it clears the preceding buffer and utilizes the current buffer for communication. In practical scenarios, it is possible that the preceding call had a lower number of tokens compared to the current call, resulting in some elements remaining ambiguous and potentially leading to race conditions for subsequent calls.

This PR addresses this issue by introducing an additional variable in the buffer flags, capturing the number of tokens in the preceding call. Consequently, the kernel will clear the buffer based on this variable rather than the current kernel grid.

Furthermore, this PR resolves an issue that arises when allreduce_strategy is set to MNNVL but the allreduce operation necessitates a fallback. The fallback path fails to acknowledge MNNVL as a valid strategy. Therefore, it is imperative to modify the strategy as well at the fallback level.

Test Coverage

pytest tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py

11.73s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float32-hidden:7168-seqlen:[128]]
11.57s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float32-hidden:7168-seqlen:[128]]
10.12s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float16-hidden:7168-seqlen:[128]]
10.12s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float16-hidden:7168-seqlen:[32]]
10.10s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:bfloat16-hidden:7168-seqlen:[31, 11, 27, 4]]
10.09s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:bfloat16-hidden:7168-seqlen:[31, 11, 27, 4]]
10.07s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float16-hidden:7168-seqlen:[31, 11, 27, 4]]
10.06s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:bfloat16-hidden:7168-seqlen:[4]]
10.05s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float16-hidden:7168-seqlen:[128]]
10.03s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:bfloat16-hidden:7168-seqlen:[128]]
10.02s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float32-hidden:7168-seqlen:[32]]
10.02s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float16-hidden:7168-seqlen:[15]]
10.02s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float32-hidden:7168-seqlen:[1]]
10.01s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float32-hidden:7168-seqlen:[4]]
10.01s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:bfloat16-hidden:7168-seqlen:[128]]
10.01s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float16-hidden:7168-seqlen:[32]]
10.01s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:bfloat16-hidden:7168-seqlen:[1]]
10.01s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float32-hidden:7168-seqlen:[15]]
10.00s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:bfloat16-hidden:7168-seqlen:[15]]
10.00s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:bfloat16-hidden:7168-seqlen:[15]]
10.00s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float32-hidden:7168-seqlen:[15]]
10.00s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float16-hidden:7168-seqlen:[1]]
9.98s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:bfloat16-hidden:7168-seqlen:[4]]
9.98s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float32-hidden:7168-seqlen:[4]]
9.98s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float32-hidden:7168-seqlen:[32]]
9.97s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float16-hidden:7168-seqlen:[4]]
9.97s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float16-hidden:7168-seqlen:[15]]
9.96s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float32-hidden:7168-seqlen:[1]]
9.95s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float16-hidden:7168-seqlen:[31, 11, 27, 4]]
9.93s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:bfloat16-hidden:7168-seqlen:[32]]
9.93s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:bfloat16-hidden:7168-seqlen:[32]]
9.90s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:bfloat16-hidden:7168-seqlen:[1]]
9.89s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float16-hidden:7168-seqlen:[4]]
9.36s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float32-hidden:7168-seqlen:[31, 11, 27, 4]]
9.29s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[no_fusion-dtype:float32-hidden:7168-seqlen:[31, 11, 27, 4]]
9.26s call     _torch/multi_gpu/test_mnnvl_allreduce.py::test_row_linear_residual_norm_fusion[fusion-dtype:float16-hidden:7168-seqlen:[1]]

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@timlee0212 timlee0212 requested a review from a team as a code owner July 21, 2025 22:39
@timlee0212 timlee0212 requested review from yilin-void and brb-nv July 21, 2025 22:39
Copy link

coderabbitai bot commented Jul 21, 2025

Walkthrough

The changes update buffer clearing and flag management in the MNNVL allreduce CUDA kernels and their Python interface. They introduce a new buffer flag for tracking tokens to clear, adjust buffer clearing logic for correctness, extend support for half-precision data types, and ensure fallback strategies are handled explicitly in the Python layer. The unit tests are enhanced to support multiple sequence lengths simultaneously. No public interfaces are removed.

Changes

File(s) Change Summary
cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu Added LamportFlags struct for buffer flag encapsulation; moved offset pointer in buffer_flags from index 3 to 4; added num_tokens_to_clear to compute clearing workload; expanded buffer clearing loops before/after reduce-broadcast phases; templated half-precision support; updated buffer_flags usage for synchronization; extended kernel launches for half precision.
tensorrt_llm/_torch/distributed/ops.py Extended buffer_flags tensor from 4 to 5 elements by adding num_tokens_to_clear; updated get_supported_dtypes() to include float16; modified AllReduce.forward to fallback from MNNVL to AUTO strategy explicitly when needed.
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py Generalized tests and helper functions to support multiple sequence lengths simultaneously; updated signatures to accept lists of tensors; reused single AllReduce instance for all sequence lengths; parameterized tests over dtypes and fusion options.

Estimated code review effort

3 (~45 minutes)

Suggested reviewers

  • symphonylyh
  • lucifer1004
  • kaiyux

Poem

🐇 Hopping through kernels deep,
Flags and tokens now we keep.
Half-precision joins the race,
Clearing buffers, no stale trace.
Python’s fallback, strategy true,
Multiple lengths, tests anew.
Allreduce shines, fresh and bright,
CodeRabbit’s work, a pure delight! ✨

✨ Finishing Touches
  • 📝 Generate Docstrings

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@timlee0212 timlee0212 self-assigned this Jul 21, 2025
@timlee0212 timlee0212 changed the title [nvbugs 5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce [nvbugs/5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce Jul 21, 2025
@timlee0212 timlee0212 changed the title [nvbugs/5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce [fix][nvbugs/5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce Jul 21, 2025
@timlee0212
Copy link
Collaborator Author

/bot run

@timlee0212 timlee0212 requested review from kaiyux and zongfeijing July 22, 2025 04:08
@tensorrt-cicd
Copy link
Collaborator

PR_Github #12515 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12515 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9302 completed with status: 'FAILURE'

@timlee0212
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12586 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12586 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9369 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@timlee0212 timlee0212 changed the title [fix][nvbugs/5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce [fix][nvbugs/5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce and add FP16 support. Jul 22, 2025
@timlee0212
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12610 [ run ] triggered by Bot

@timlee0212
Copy link
Collaborator Author

/bot kill

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🔭 Outside diff range comments (1)
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py (1)

94-153: Remove unused loop variable

Only the unused i needs to be cleaned up—reusing the same AllReduce instance across different sequence lengths is safe, as it doesn’t cache tensor shapes.

• In tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py, replace the loop:

-    for i, (x, residual, reference_output) in enumerate(
-            zip(x_list, residual_list, reference_output_list)):
+    for x, residual, reference_output in zip(
+            x_list, residual_list, reference_output_list):
🧹 Nitpick comments (1)
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py (1)

188-230: Consider improving readability of the complex list comprehension.

The test logic correctly handles multiple sequence lengths, but the list comprehension in the MPIPoolExecutor call (lines 211-226) is quite complex and could benefit from improved readability.

Consider extracting the data preparation logic for better readability:

+    # Prepare per-rank data for each sequence length
+    per_rank_args = []
+    for i in range(tensor_parallel_size):
+        rank_x_list = [x[i, :, :] for x in x_list]
+        per_rank_args.append((
+            tensor_parallel_size,
+            row_linear_residual_norm_fusion_forward,
+            rank_x_list,
+            residual_list,
+            norm_weight,
+            eps,
+            hidden_size,
+            dtype,
+            fusion,
+            reference_output_list,
+        ))
+
     with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
         results = executor.map(
             run_single_rank,
-            *zip(*[
-                (
-                    tensor_parallel_size,
-                    row_linear_residual_norm_fusion_forward,
-                    [
-                        x[i, :, :] for x in x_list
-                    ],  # Extract the i-th rank's data from each sequence length
-                    residual_list,
-                    norm_weight,
-                    eps,
-                    hidden_size,
-                    dtype,
-                    fusion,
-                    reference_output_list,
-                ) for i in range(tensor_parallel_size)
-            ]),
+            *zip(*per_rank_args),
         )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3a82827 and b05210a.

📒 Files selected for processing (1)
  • tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py (2)

47-78: LGTM! Function signature correctly updated for multi-sequence support.

The refactoring from single tensors to lists of tensors is well-implemented and aligns with the goal of supporting multiple sequence lengths simultaneously.


158-179: Excellent parameterization for comprehensive test coverage.

The updated parameterization effectively tests:

  • Multiple sequence lengths simultaneously (including the multi-length case [31, 11, 27, 4])
  • FP16 support as mentioned in the PR objectives
  • Various hidden sizes and fusion options

This provides thorough coverage of the enhanced MNNVL allreduce functionality.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12617 [ kill ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12610 [ run ] completed with state ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12617 [ kill ] completed with state SUCCESS
Successfully killed previous jobs for commit b05210a

@timlee0212
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12618 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12618 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #9387 completed with status: 'FAILURE'

@timlee0212
Copy link
Collaborator Author

/bot run

1 similar comment
@timlee0212
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12663 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #12663 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9417 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants