Skip to content

Conversation

lfr-0531
Copy link
Collaborator

@lfr-0531 lfr-0531 commented Aug 12, 2025

Summary by CodeRabbit

  • New Features

    • Data-parallel–aware MoE token limits (dp_size-based) and a public dp_size accessor.
    • Chunked MoE execution for very large token counts with overlapped processing.
    • Public workspace management to allocate and reuse preallocated buffers.
  • Performance Improvements

    • Fewer internal allocations via caller-provided output/scale buffers and in-place workspaces.
    • Better latency hiding using multiple CUDA streams and overlap events.
    • Optimized FP8 group quantization path.
  • API Changes

    • Several MoE/quantization methods now accept caller-provided output/scale buffers and workspace handles.

Description

  • Pre-allocate workspaces for DeepGEMM MoE to avoid frequent cudaFree/cudaMalloc calls.
  • Set a default moe_max_num_tokens = 18688 to DeepGEMM MoE to avoid OOM. If the num_tokens > 2 * moe_max_num_tokens, we will enable chunked moe.

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@lfr-0531 lfr-0531 requested review from litaotju and yuxianq August 12, 2025 05:11
@lfr-0531 lfr-0531 requested a review from a team as a code owner August 12, 2025 05:11
Copy link
Contributor

coderabbitai bot commented Aug 12, 2025

📝 Walkthrough

Walkthrough

Replaces world_size/use_dp gating with a dp_size-derived MoE token cap (Mapping.dp_size). DeepGemm fused MoE was reworked into a workspace-driven, chunked, DP-aware backend (new get_workspace, forward, workspace-backed forward_chunk, modified GEMM/quant signatures). FP8 post-quant now writes into caller-provided buffers and returns only scales.

Changes

Cohort / File(s) Summary
MoE DP token-cap refactor
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py, tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py, tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
Replace world_size/use_dp gating with unconditional scaling by mapping.dp_size; compute moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size and set self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens; update aux-stream/event gating to compare against the dp_size baseline. No public signatures changed.
DeepGemm MoE workspace & chunking overhaul
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
Add get_workspace(m_max, group_size); add top-level forward for chunked/DP-aware execution; extend forward_chunk(..., workspace); add set_strides(workspace, g, m, k) helper; change FP8/GEMM wrappers to accept preallocated out buffers (output, output_s, d), remove internal allocations; introduce two CUDA streams and overlap event for chunking; clamp default moe_max_num_tokens to a hard max. Public API surface expanded/updated.
Mapping DP size
tensorrt_llm/mapping.py
Add @property dp_size returning tp_size when enable_attention_dp is True, otherwise 1.
FP8 post-quant API change
tensorrt_llm/quantization/utils/fp8_utils.py
silu_and_mul_masked_post_quant_fwd now accepts preallocated output and output_scale buffers, writes in-place, and returns only output_scale (signature and return changed).
Quant/GEMM wrapper signatures
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py, tensorrt_llm/quantization/utils/fp8_utils.py
masked_index_copy_group_quant_fp8 and low-level GEMM wrapper signatures updated to take output/scale buffers (no internal allocation) and to use caller-provided outputs.

Sequence Diagram(s)

sequenceDiagram
  participant Client
  participant DeepGemmMoE as DeepGemmFusedMoE
  participant Stream0
  participant Stream1
  participant Collectives as DP_Collectives

  Client->>DeepGemmMoE: forward(x, router_logits, ...)
  DeepGemmMoE->>DeepGemmMoE: compute token_count, moe_max_num_tokens
  alt token_count > threshold
    loop per chunk
      DeepGemmMoE->>DeepGemmMoE: get_workspace(...)
      DeepGemmMoE->>Stream0: forward_chunk(chunk0, workspace0)
      Stream0-->>DeepGemmMoE: partial outputs/scales
      DeepGemmMoE->>Stream1: forward_chunk(chunk1, workspace1) (overlap)
    end
    DeepGemmMoE->>Collectives: reducescatter/all_reduce (if DP)
    Collectives-->>DeepGemmMoE: reduced pieces
    DeepGemmMoE->>DeepGemmMoE: assemble final output
  else
    DeepGemmMoE->>Stream0: forward_chunk(all, workspace)
    Stream0-->>DeepGemmMoE: output
  end
  DeepGemmMoE-->>Client: final output
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

SW Architecture

Suggested reviewers

  • yizhang-nv
  • liji-nv
  • jinyangyuan-nvidia
  • yilin-void
  • yuantailing
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@lfr-0531 lfr-0531 force-pushed the user/fanrongl/add_workspace_for_dg_moe branch from 4a71d56 to 2f94ddf Compare August 12, 2025 05:17
@lfr-0531
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14910 [ run ] triggered by Bot

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)

89-160: Critical: Non-contiguous output buffer after transpose operation.

The function modifies the output_s parameter in-place with a transpose operation at line 159, which makes it non-contiguous. This violates the expected contract that the output buffer should remain contiguous for subsequent operations.

Apply this diff to fix the issue:

-    output_s = output_s.transpose(1, 2)[:, :col_size, :]
-    return output_s
+    output_s_transposed = output_s.transpose(1, 2)[:, :col_size, :].contiguous()
+    return output_s_transposed
🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (3)

492-510: Extract magic numbers into named constants.

The code uses several magic numbers (128, 4) for alignment and padding without explanation. These should be extracted into named constants for better maintainability.

Add constants at the class level:

class DeepGemmFusedMoE(CutlassFusedMoE):
    # FP8 alignment requirements
    FP8_TOKEN_ALIGNMENT = 128
    FP8_SCALE_ALIGNMENT = 4
    FP8_GROUP_SIZE = 128

Then update the code:

-        m_max = fp8_utils.align(x.shape[0], 128)
+        m_max = fp8_utils.align(x.shape[0], self.FP8_TOKEN_ALIGNMENT)
         act_input_fp8 = set_strides(workspace["workspace_0"],
                                     self.expert_size_per_partition, m_max,
                                     self.hidden_size)

-        m_padded = fp8_utils.align(m_max, 4)
-        scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
-        scale_k_padded = fp8_utils.align(scale_k, 4)
+        m_padded = fp8_utils.align(m_max, self.FP8_SCALE_ALIGNMENT)
+        scale_k = fp8_utils.ceil_div(self.hidden_size, self.FP8_GROUP_SIZE)
+        scale_k_padded = fp8_utils.align(scale_k, self.FP8_SCALE_ALIGNMENT)

608-610: Improve conditional logic for clarity.

The nested condition for determining num_chunks could be simplified for better readability.

-        num_chunks = 1
-        if num_rows > self.moe_max_num_tokens * 2:
-            num_chunks = (num_rows + self.moe_max_num_tokens -
-                          1) // self.moe_max_num_tokens
+        # Calculate number of chunks needed (minimum 1)
+        num_chunks = max(1, (num_rows + self.moe_max_num_tokens - 1) // self.moe_max_num_tokens)
+        # Only split if we need more than 2 chunks (to use dual-stream overlap)
+        if num_chunks <= 2:
+            num_chunks = 1

663-712: Consider refactoring the chunked execution logic.

The chunked execution logic with stream overlapping is complex and contains repeated patterns. Consider extracting this into a separate method for better maintainability and testability.

Extract the stream-based chunked execution into a helper method:

def _execute_chunked_with_overlap(self, x_list, router_logits_list, 
                                  all_rank_num_tokens_list, use_dp_padding,
                                  workspace_0, workspace_1):
    """Execute chunked MoE with stream overlap for latency hiding."""
    outputs_list = []
    
    for idx_chunk, (x, router_logits) in enumerate(zip(x_list, router_logits_list)):
        workspace = workspace_0 if idx_chunk % 2 == 0 else workspace_1
        stream = self.aux_stream if idx_chunk % 2 == 0 else None
        
        # Forward pass (potentially on aux stream)
        if stream:
            with torch.cuda.stream(stream):
                outputs = self._forward_chunk_wrapper(
                    x, router_logits, idx_chunk, 
                    all_rank_num_tokens_list, use_dp_padding, workspace)
        else:
            outputs = self._forward_chunk_wrapper(
                x, router_logits, idx_chunk,
                all_rank_num_tokens_list, use_dp_padding, workspace)
        
        # Reduce previous chunk (overlap with current forward)
        if idx_chunk > 0:
            self._reduce_previous_chunk(outputs_list, idx_chunk - 1,
                                       all_rank_num_tokens_list, use_dp_padding,
                                       stream is None)
        
        outputs_list.append(outputs)
    
    # Final reduction
    self._reduce_final_chunk(outputs_list, num_chunks, 
                            all_rank_num_tokens_list, use_dp_padding)
    
    return outputs_list
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4a71d56 and 2f94ddf.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (10 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1 hunks)
  • tensorrt_llm/mapping.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • tensorrt_llm/mapping.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
  • tensorrt_llm/quantization/utils/fp8_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
🧠 Learnings (1)
📚 Learning: 2025-08-09T20:57:04.067Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.067Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py (1)

85-86: LGTM! Clean implementation of DP-aware token scaling.

The unconditional DP-size-based token scaling is correctly implemented, removing the previous conditional logic based on world_size and use_dp. The fallback pattern using the or operator is clean and aligns with the same pattern used in other MoE implementations across the PR.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1)

154-157: LGTM! Consistent DP-size-based token limit calculation.

The changes correctly implement DP-size-based token scaling, matching the pattern used in other MoE modules. The auxiliary stream initialization is properly gated on the comparison between the configured and calculated token limits.

@lfr-0531 lfr-0531 requested a review from Barry-Delaney August 12, 2025 06:11
@tensorrt-cicd
Copy link
Collaborator

PR_Github #14910 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11254 completed with status: 'FAILURE'

lfr-0531 and others added 5 commits August 12, 2025 18:04
Signed-off-by: Fanrong Li <[email protected]>
Signed-off-by: Fanrong Li <[email protected]>
@lfr-0531 lfr-0531 force-pushed the user/fanrongl/add_workspace_for_dg_moe branch from 2f94ddf to d6150c5 Compare August 12, 2025 10:04
@lfr-0531
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14953 [ run ] triggered by Bot

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2f94ddf and d6150c5.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (10 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py (1 hunks)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (1 hunks)
  • tensorrt_llm/mapping.py (1 hunks)
  • tensorrt_llm/quantization/utils/fp8_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
  • tensorrt_llm/quantization/utils/fp8_utils.py
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a Python file, prefer docstrings over comments.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the class docstring.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
  • tensorrt_llm/mapping.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
  • tensorrt_llm/mapping.py
🧠 Learnings (1)
📚 Learning: 2025-08-09T20:57:04.067Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.067Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py (1)

85-86: Ignore incorrect NoneType/or concerns for moe_max_num_tokens

ModelConfig.max_num_tokens is defined as an int (default 8192) in the _torch dataclass, so it cannot be None at initialization. Using or to fall back when moe_max_num_tokens is unset (i.e. None) is valid, and an explicit 0 value for moe_max_num_tokens would likewise be a non‐standard configuration that our APIs do not expect.

No changes required here; please proceed.

Likely an incorrect or invalid review comment.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #14953 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11287 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@yuxianq yuxianq merged commit 1bbc0e3 into NVIDIA:main Aug 13, 2025
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants