
[TRTLLM-5627] feat: Implement pytorch sampler for MTP #6245


Open
wants to merge 7 commits into main from deepinfra_mtp_torch_sampler

Conversation

nvxuanyuc
Collaborator

@nvxuanyuc nvxuanyuc commented Jul 22, 2025

[feat] Implement pytorch sampler for MTP

Description

  • Adds support for advanced sampling in the PyTorch path for MTP with speculative decoding
    • Previously, only greedy mode was supported.
    • Implements temperature, top-p, top-k, and min-p sampling parameters in Python when using MTP speculative decoding (for DeepSeek) @pathorn [#5627]
  • Adds support for returning log-probs from the PyTorch sampler, related to [#5620]

The default behavior of the MTP PyTorch decoder remains greedy sampling. Advanced sampling can be enabled via the enable_mixed_sampler flag in TorchLlmArgs.
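
For illustration, here is a minimal usage sketch based on the description above; the exact import paths and the way enable_mixed_sampler is forwarded to TorchLlmArgs are assumptions rather than code taken from this diff.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import MTPDecodingConfig

# Assumption: enable_mixed_sampler is forwarded to TorchLlmArgs by the LLM API.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",
    speculative_config=MTPDecodingConfig(num_nextn_predict_layers=1),
    enable_mixed_sampler=True,
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.8, top_p=0.95, top_k=40),
)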

Test Coverage

  • Added a test for greedy mode (temperature <= 1e-2) using the new PyTorch sampler.
  • Tests for advanced sampling modes are not yet included.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@nvxuanyuc nvxuanyuc requested review from a team as code owners July 22, 2025 03:37
Contributor

coderabbitai bot commented Jul 22, 2025

📝 Walkthrough

Walkthrough

This change introduces advanced multi-token prediction (MTP) sampling support in the PyTorch execution engine, enabling per-token sampling parameter control for speculative decoding. It adds new fields and methods to propagate and utilize sampling parameters (temperature, top-k, top-p, min-p) throughout the model engine, speculative metadata, and MTP worker. A new batch sampling function and corresponding tests are also included.

Changes

Cohort / File(s) Change Summary
Advanced MTP Sampler Integration
tensorrt_llm/_torch/pyexecutor/model_engine.py
Adds advanced MTP sampler support: tracks per-request sampling params, propagates them for speculative decoding, and updates control flow to handle advanced sampler mode.
PyTorch-native Sampling Utilities
tensorrt_llm/_torch/pyexecutor/sampler.py
Introduces modular PyTorch-native sampling functions (temperature, top-k, top-p, min-p, batch sampling) and a unified batch sampling interface.
Speculative Metadata Extensions
tensorrt_llm/_torch/speculative/interface.py
Adds optional tensor fields to SpecMetadata for storing per-request sampling parameters.
MTP Worker and Metadata Updates
tensorrt_llm/_torch/speculative/mtp.py
Adds CUDA tensor fields and setup/update methods to MTPSpecMetadata, enables advanced sampling in MTPWorker when configured, and integrates the new batch sampler.
Model Config Flag
tensorrt_llm/_torch/model_config.py
Adds enable_mixed_sampler boolean field to ModelConfig dataclass, defaulting to False.
Advanced Sampler Unit Test
tests/unittest/_torch/speculative/test_mtp.py
Adds a parameterized unit test for sample_and_accept_draft_tokens with the advanced PyTorch sampler in greedy mode, verifying correct behavior with deterministic settings.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant PyTorchModelEngine
    participant MTPWorker
    participant SpecMetadata
    participant Sampler

    User->>PyTorchModelEngine: forward(requests, ...)
    PyTorchModelEngine->>PyTorchModelEngine: _prepare_tp_inputs()
    PyTorchModelEngine->>SpecMetadata: update_advanced_mtp_sampling_params(...)
    PyTorchModelEngine->>SpecMetadata: _set_up_advanced_mtp_sampling(...)
    PyTorchModelEngine->>MTPWorker: sample_and_accept_draft_tokens(input_ids, logits, spec_metadata, ...)
    alt enable_mixed_sampler
        MTPWorker->>Sampler: sampling_batch(logits, temperatures, top_k, top_p, min_p)
        Sampler-->>MTPWorker: sampled_tokens, log_probs
    else
        MTPWorker->>MTPWorker: greedy_sample(logits)
    end
    MTPWorker-->>PyTorchModelEngine: accepted_tokens
    PyTorchModelEngine-->>User: output

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • HuiGao-NV
  • Funatiq


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
tests/unittest/_torch/speculative/test_mtp.py (1)

370-370: Fix line length violation.

The line exceeds the 120-character limit as flagged by static analysis.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in 
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

264-267: Verify the global impact of setting torch.manual_seed(0).

Setting a global PyTorch seed in the constructor could have unintended side effects on other operations. Consider:

  1. This affects all PyTorch random operations, not just sampling
  2. It might interfere with user-controlled randomness
  3. Consider using a local generator instead of global seed

Consider using a dedicated random generator for sampling operations:

-        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
-        # operations that avoid torch.multinomial's CPU-GPU sync overhead
-        torch.manual_seed(0)
+        # Create dedicated generator for consistent multi-GPU sampling
+        # to avoid torch.multinomial's CPU-GPU sync overhead
+        self.sampling_generator = torch.Generator(device='cuda')
+        self.sampling_generator.manual_seed(0)

Then pass this generator to sampling operations that need deterministic behavior.
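
As a rough illustration of that suggestion (not code from this PR), the dedicated generator could be threaded into the exponential-noise sampling step, since Tensor.exponential_ accepts a generator keyword:

import torch

def random_sample_with_generator(probs: torch.Tensor,
                                 generator: torch.Generator) -> torch.Tensor:
    # Exponential(1) noise drawn from the dedicated generator keeps sampling
    # deterministic without touching the global PyTorch RNG state.
    noise = torch.empty_like(probs).exponential_(1.0, generator=generator)
    return probs.div(noise).argmax(dim=-1).view(-1)

# Usage (assuming a CUDA device; the generator must live on the same device as probs):
#   gen = torch.Generator(device="cuda")
#   gen.manual_seed(0)
#   tokens = random_sample_with_generator(probs, gen)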


1163-1195: Consider moving helper functions to class level for better organization.

These helper functions are defined inside _prepare_tp_inputs but could be reused elsewhere. Consider making them class methods or static methods.

Move these functions to class level:

-        def get_request_temperature(request: LlmRequest) -> float:
-            if not request.sampling_config.temperature:
-                return 0.7
-            temperature = request.sampling_config.temperature[0]
-            if 0 < temperature < 1e-2:
-                # temperature less than 0.01 may cause numerical errors
-                temperature = 0.01
-            return temperature
+    @staticmethod
+    def _get_request_temperature(request: LlmRequest) -> float:
+        if not request.sampling_config.temperature:
+            return 0.7
+        temperature = request.sampling_config.temperature[0]
+        if 0 < temperature < 1e-2:
+            # temperature less than 0.01 may cause numerical errors
+            temperature = 0.01
+        return temperature

Apply similar changes to the other helper functions and update the call sites accordingly.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fddb7f1 and cec9318.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

🔇 Additional comments (20)
tensorrt_llm/_torch/speculative/interface.py (1)

135-142: LGTM! Clean addition of sampling parameter fields.

The new optional tensor fields for sampling parameters (temperatures, top_k, top_p, min_p) are well-structured and follow the existing pattern in the SpecMetadata dataclass. The type annotations and comments are clear and appropriate.

examples/llm-api/quickstart_advanced.py (2)

115-117: LGTM! Clean addition of command-line argument.

The new --use_advanced_mtp_sampler flag follows the established pattern for boolean command-line arguments with an appropriate default value.


169-170: LGTM! Proper integration of the new flag.

The use_advanced_mtp_sampler parameter is correctly passed to the MTPDecodingConfig constructor, maintaining consistency with the command-line argument.

tensorrt_llm/_torch/speculative/mtp.py (2)

11-11: LGTM! Appropriate import addition.

The import of sampling_batch function is correctly added to support the advanced MTP sampler functionality.


825-833: LGTM! Well-structured conditional sampling logic.

The implementation demonstrates good practices:

  1. Backward compatibility: Maintains the existing greedy sampling as the default behavior
  2. Clear conditional logic: The flag-based switching is easy to understand and maintain
  3. Future-proofing: Acknowledges the unused target_log_probs for future log probability support
  4. Clean integration: The advanced sampler integrates seamlessly with the existing acceptance algorithm

The approach minimizes risk while enabling the new advanced sampling functionality.
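
For orientation, here is a sketch of the kind of flag-based switch described above. The helper below is hypothetical; only sampling_batch and its argument order come from this PR.

import torch
from typing import Optional, Tuple

from tensorrt_llm._torch.pyexecutor.sampler import sampling_batch

def select_target_tokens(
        logits: torch.Tensor,
        enable_mixed_sampler: bool,
        temperatures: Optional[torch.Tensor] = None,
        top_k: Optional[torch.Tensor] = None,
        top_p: Optional[torch.Tensor] = None,
        min_p: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Hypothetical helper mirroring the conditional logic described above.
    if enable_mixed_sampler:
        # Advanced path: per-request temperature / top-k / top-p / min-p
        # tensors must be provided when the flag is enabled.
        return sampling_batch(logits, temperatures, top_k, top_p, min_p)
    # Default path: greedy sampling, unchanged from the previous behavior.
    return torch.argmax(logits, dim=-1), None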

tensorrt_llm/llmapi/llm_args.py (1)

417-422: LGTM! Configuration improvements enhance usability.

The changes improve the MTPDecodingConfig class by:

  1. Making several fields optional with sensible conservative defaults
  2. Adding the new use_advanced_mtp_sampler flag to enable the advanced sampling feature
  3. Following consistent patterns with other boolean configuration flags

The default values are appropriate:

  • num_nextn_predict_layers=1 maintains backward compatibility
  • Boolean flags default to False for conservative behavior
  • relaxed_topk=1 and relaxed_delta=0. provide safe starting points

This provides a clean API where users can enable advanced sampling by simply setting use_advanced_mtp_sampler=True without having to specify all the other parameters.

tests/unittest/_torch/speculative/test_mtp.py (1)

333-401: LGTM! Good test coverage for advanced MTP sampler in greedy mode.

The test implementation correctly validates the advanced PyTorch sampler functionality with proper setup of sampling parameters to enforce greedy behavior. The deterministic seeding and reuse of existing test cases ensures consistency and reproducibility.

However, note that this test only covers greedy mode (temperature ≤ 0.01). Consider adding future tests for actual advanced sampling modes (temperature > 0.01) to validate the full functionality of the advanced sampler.

tensorrt_llm/_torch/pyexecutor/model_engine.py (5)

20-20: LGTM!

The import is necessary for accessing sampling configurations from request objects.


284-286: LGTM!

Clear and logical detection of advanced MTP sampler mode.


382-398: LGTM!

Appropriate CUDA tensor allocations for sampling parameters with correct sizes and data types.


1229-1234: LGTM!

Sampling parameters are correctly collected and replicated for each token position across different request types.

Also applies to: 1317-1326, 1356-1365, 1398-1407


1511-1526: LGTM!

Efficient non-blocking CUDA tensor copies and proper assignment to spec_metadata for advanced MTP sampling.

Also applies to: 1601-1607

tensorrt_llm/_torch/pyexecutor/sampler.py (8)

4-4: LGTM: Clean import addition

The additional typing imports are necessary for the new type annotations in the sampling functions.


154-167: LGTM: Well-implemented sampling function

The function correctly implements top-k and top-p filtering with efficient in-place operations. The use of custom random sampling to avoid CPU-GPU synchronization is a good performance optimization.


169-178: LGTM: Clever sampling implementation

This function uses the Gumbel-max trick effectively to avoid CPU-GPU synchronization. The mathematical approach is sound and the performance justification is clear.


180-198: LGTM: Correct min-p implementation

The adaptive probability thresholding logic is correctly implemented, using the standard approach of scaling min_p by the maximum probability per sequence.


200-232: LGTM: Comprehensive top-k/top-p implementation

The function correctly implements both top-k and top-p filtering with proper handling of edge cases like the "at least one" guarantee. The sorting and scatter approach ensures correctness.
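
For reference, a sort-and-scatter filter along these lines might look as follows; this is an illustrative sketch under the assumptions noted in the comments, not the PR's exact implementation.

import torch
from typing import Optional

def apply_top_k_top_p_sketch(logits: torch.Tensor,
                             k: Optional[torch.Tensor],
                             p: Optional[torch.Tensor]) -> torch.Tensor:
    # Sort ascending once, mask in sorted order, then scatter back.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    if k is not None:
        # Clamp defensively: k <= 0 is normalized upstream, and a very large k
        # (the "disabled" sentinel) simply keeps the whole vocabulary.
        k = k.to(torch.long).clamp(min=1, max=logits_sort.size(1))
        # The k-th largest value per row is the cutoff; everything below it is masked.
        cutoff = logits_sort.gather(1, (logits_sort.size(1) - k).unsqueeze(1))
        logits_sort.masked_fill_(logits_sort < cutoff, float("-inf"))
    if p is not None:
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        # Drop the low-probability tail whose cumulative mass stays below 1 - p,
        # but always keep at least the single most likely token (last column).
        top_p_mask = probs_sum <= 1.0 - p.unsqueeze(1)
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, float("-inf"))
    # Scatter the filtered values back to the original token order.
    return torch.empty_like(logits_sort).scatter_(1, logits_idx, logits_sort)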


234-236: LGTM: Simple and correct greedy sampling

Clean implementation using argmax with proper tensor reshaping.


238-244: LGTM: Efficient temperature scaling

Correct implementation with efficient in-place operations and proper broadcasting.
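
As a point of reference, minimal sketches of these two helpers might look like the following; they are illustrative only, not the exact code under review.

import torch

def greedy_sample_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Highest-scoring token per row, flattened to a 1D tensor of token ids.
    return torch.argmax(logits, dim=-1).view(-1)

def apply_temperature_sketch(logits: torch.Tensor,
                             temperatures: torch.Tensor) -> torch.Tensor:
    # In-place division; per-request temperatures broadcast across the vocab dim.
    return logits.div_(temperatures.unsqueeze(dim=1))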


246-264: LGTM: Well-designed batch sampling function

This function effectively combines all sampling techniques with proper handling of greedy vs. random sampling. The temperature threshold logic and log-probability calculation are correctly implemented.

@nvxuanyuc nvxuanyuc requested review from jhaotingc and amukkara July 22, 2025 16:23
@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch 2 times, most recently from 4a68f67 to 84d09a0 Compare July 22, 2025 16:34
@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/unittest/_torch/speculative/test_mtp.py (1)

370-370: Fix line length violation.

The comment line exceeds the 120-character limit flagged by static analysis.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in 
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

264-266: Consider making the deterministic seed configurable.

The hardcoded seed value of 0 ensures consistent multi-GPU sampling, but consider making this configurable through the PyTorchConfig to provide flexibility for different use cases while maintaining the performance benefits of avoiding CPU-GPU synchronization.

-        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
-        # operations that avoid torch.multinomial's CPU-GPU sync overhead
-        torch.manual_seed(0)
+        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
+        # operations that avoid torch.multinomial's CPU-GPU sync overhead
+        seed = getattr(pytorch_backend_config, 'sampling_seed', 0)
+        torch.manual_seed(seed)

1163-1194: LGTM: Well-designed helper functions with proper edge case handling.

The helper functions correctly extract sampling parameters with appropriate defaults and constraints. The temperature clamping to avoid numerical errors and top_k max value handling are particularly well thought out.

Consider extracting the magic numbers to constants:

+TEMPERATURE_MIN_THRESHOLD = 1e-2
+TEMPERATURE_MIN_VALUE = 0.01
+TOP_K_DISABLED_VALUE = 2147483647  # Max int32

 def get_request_temperature(request: LlmRequest) -> float:
     if not request.sampling_config.temperature:
         return 0.7
     temperature = request.sampling_config.temperature[0]
-    if 0 < temperature < 1e-2:
-        temperature = 0.01
+    if 0 < temperature < TEMPERATURE_MIN_THRESHOLD:
+        temperature = TEMPERATURE_MIN_VALUE
     return temperature
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cec9318 and 84d09a0.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (6)
  • tensorrt_llm/llmapi/tokenizer.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧰 Additional context used
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.703Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (8)
tests/unittest/_torch/speculative/test_mtp.py (2)

333-401: LGTM! Well-structured test for the advanced MTP sampler.

The new test method effectively validates that the advanced PyTorch sampler produces identical results to the standard sampler when configured for greedy mode. The test design is solid:

  • Proper parameterization reusing existing test cases
  • Deterministic seeding for reproducible results
  • Correct configuration of sampling parameters to enforce greedy mode (temperature ≤ 0.01)
  • Appropriate assertions matching the reference implementation

369-374: Greedy sampling parameters verified

Confirmed that in tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_temperature, any temperature below 0.01 is clamped up to 0.01. Therefore, extending temperatures with 0.01 correctly enforces the intended greedy sampling boundary. No changes required.

tensorrt_llm/_torch/pyexecutor/model_engine.py (6)

20-20: LGTM: Import addition is necessary for new functionality.

The LlmRequest import is properly placed and required for accessing sampling configuration in the advanced MTP sampler.


284-285: LGTM: Correct logic for advanced MTP sampler detection.

The boolean flag correctly identifies when the advanced MTP sampler should be active by checking all necessary conditions in the proper sequence.


382-398: LGTM: Proper CUDA tensor allocation for sampling parameters.

The tensor allocation correctly sizes buffers for batch_size × (max_draft_len + 1) elements, uses appropriate data types, and efficiently allocates only when the advanced sampler is enabled.


1157-1161: LGTM: Correct parameter replication for draft tokens.

The sampling parameter lists are properly initialized and populated with the correct replication pattern for each request type, ensuring parameters are available for both the main token and all draft tokens.

Also applies to: 1229-1233, 1318-1326, 1357-1365, 1399-1407


1512-1526: LGTM: Efficient CUDA tensor copying and metadata integration.

The implementation uses pinned memory and non-blocking copies for optimal performance, properly slices tensors to match actual usage, and cleanly integrates with the existing speculative decoding metadata structure.

Also applies to: 1602-1607


264-266: Excellent implementation of advanced MTP sampler support.

The changes successfully add support for advanced sampling parameters in MTP speculative decoding with:

  • Proper memory management through pre-allocated CUDA tensors
  • Efficient parameter extraction and replication logic
  • Clean integration with existing speculative decoding infrastructure
  • Good performance considerations (pinned memory, non-blocking transfers)
  • Appropriate conditional activation based on configuration flags

The implementation follows established patterns in the codebase and maintains backward compatibility while adding the new functionality.

Also applies to: 284-285, 382-398, 1157-1607

@tensorrt-cicd
Collaborator

PR_Github #12593 [ run ] triggered by Bot

@jhaotingc jhaotingc requested a review from netanel-haber July 22, 2025 16:57
@jhaotingc
Collaborator

Hi @netanel-haber, @nvxuanyuc has addressed the comments from PR 5627. Would you mind reviewing this PR again? Thanks.

@tensorrt-cicd
Collaborator

PR_Github #12593 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9371 completed with status: 'FAILURE'

@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 84d09a0 to 7f31555 Compare July 23, 2025 03:10
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)

370-370: Fix line length violation.

The line exceeds the 120-character limit enforced by the linter.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in 
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 84d09a0 and 7f31555.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/llmapi/llm_args.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/llmapi/tokenizer.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)

333-401: LGTM! Well-structured test for the advanced MTP sampler.

The test correctly validates the advanced PyTorch sampler functionality by:

  • Using deterministic seeding for reproducible results
  • Properly configuring sampling parameters to enforce greedy mode (temperature ≤ 0.01)
  • Reusing existing test cases to ensure consistent behavior with the default sampler
  • Following the established test patterns in the codebase

The sampling parameter configuration looks correct for greedy mode testing.

@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 7f31555 to 607dbc5 Compare July 23, 2025 04:10
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)

370-370: Fix line length violation.

The line exceeds the 120 character limit. Consider breaking it into multiple lines for better readability.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7f31555 and 607dbc5.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • tensorrt_llm/_torch/speculative/mtp.py
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/llmapi/tokenizer.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tests/unittest/_torch/speculative/test_mtp.py (2)

333-401: Well-structured test for advanced MTP sampler in greedy mode.

The test method is properly implemented with correct parameterization, deterministic seed setting, and appropriate sampling parameter configuration to enforce greedy mode behavior. The test logic follows the established patterns and should effectively validate the advanced sampler functionality.


363-386: Sampling Parameter Threshold Confirmed

The model engine clamps any non-zero temperature below 0.01 up to 0.01 and treats temperatures ≤ 0.01 as greedy mode. Your test’s use of temperature = 0.01 correctly hits that boundary.

No changes required.

@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #12650 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12650 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9404 completed with status: 'FAILURE'

@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 607dbc5 to 2953667 Compare July 23, 2025 16:37
@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/unittest/_torch/speculative/test_mtp.py (2)

370-370: Fix line length violation.

The line exceeds the 120 character limit flagged by the linter.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in 
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]

388-401: Test execution looks correct but consider broader test coverage.

The test execution properly validates that the advanced sampler produces the same results as the original implementation in greedy mode, which is the expected behavior.

However, this test only covers greedy mode. Consider adding tests for the actual advanced sampling modes (temperature > 0.01, top-k < max_int, etc.) to fully validate the new functionality.

Would you like me to help generate additional test cases for non-greedy sampling modes to improve coverage of the advanced sampler functionality?

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 607dbc5 and 2953667.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • tensorrt_llm/llmapi/llm_args.py
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/llmapi/tokenizer.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
🧰 Additional context used
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tests/unittest/_torch/speculative/test_mtp.py (3)

333-340: LGTM! Good test structure and deterministic setup.

The test method signature follows the existing pattern and the deterministic seed ensures consistent behavior across runs, which is important for multi-GPU sampling scenarios.


342-346: Correct configuration for advanced sampler testing.

The test properly enables the advanced MTP sampler feature through the use_advanced_mtp_sampler=True flag, which is the key differentiator from the original test method.


363-387: Well-implemented parameter setup for greedy sampling mode.

The sampling parameters are correctly configured to enforce greedy behavior:

  • Temperature set to 0.01 (at the greedy boundary)
  • top_k set to max int value (no filtering)
  • top_p set to 1.0 (no filtering)
  • min_p set to 0.0 (no filtering)

The logic properly accounts for each batch's draft tokens plus one additional token.

top_k = request.sampling_config.top_k[0]

if top_k <= 0:
    top_k = 2147483647
Collaborator

Use int64.max or something else, for searchability/readability, instead of a magic number

Collaborator Author

Since Python doesn't have a built-in max constant, we now define TOP_K_DISABLED as a large value (greater than the vocab size). Any better suggestions?
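
A minimal sketch of the more readable alternative; the helper name here is hypothetical, and the review summary below notes that the final code uses torch.iinfo(torch.int32).max.

import torch

# The dtype's maximum, written without the hard-coded 2147483647.
TOP_K_DISABLED = torch.iinfo(torch.int32).max

def normalize_top_k(top_k: int) -> int:
    # Non-positive top_k means "no top-k filtering"; map it to the sentinel.
    return TOP_K_DISABLED if top_k <= 0 else top_k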


@@ -374,6 +381,23 @@ def __init__(
        self.without_logits = self.spec_config.spec_dec_mode.without_logits(
        )
        self.max_draft_len = spec_config.max_draft_len

        if self.is_advanced_mtp_sampler:
Collaborator

Please colocate all of the relevant code within the MTPWorker, not model_engine.py. Ideally, I should be able to read the sampling logic in MTP top-down, and model_engine.py is complex enough as it is.

Collaborator Author

We retain the logic for collecting request-specific speculative parameters in model_engine, as moving it to the MTP side would require duplicating the loop over all requests and passing multiple parameters into model_engine. Instead, we encapsulate this logic in a helper function to keep the code cleaner.
Other initialization steps and SpecMetadata-related settings have been moved to the MTP side.

Collaborator

I agree with @netanel-haber; I don't think collecting sampling parameters in model_engine is a good design choice. Do these parameters even belong in spec_metadata, or should we create a separate class for them? They seem general enough to justify an independent class. (They also don't seem to be specific to MTP, so the new functions shouldn't include mtp in their names.)
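
As a sketch of that idea (the class and field names here are hypothetical; the fields mirror the tensors added to SpecMetadata):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class SamplingParamsMetadata:
    # Hypothetical stand-alone container for per-request sampling tensors,
    # decoupled from speculative-decoding metadata as suggested above.
    temperatures: Optional[torch.Tensor] = None
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None
    min_p: Optional[torch.Tensor] = None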

Collaborator

@jhaotingc jhaotingc Jul 30, 2025

Hi @Funatiq @netanel-haber, I agree that putting the sampling parameters in spec_metadata is a suboptimal solution.
Is there a plan to have make_decoding_batch_input or a similar design in TorchSampler?
What's the roadmap for TorchSampler to support advanced sampling, or a design we can follow?

One reason for having these in a speculative-decoding-related class, and for parsing them in model_engine, is that sample_async in py_executor happens after the model forward, while MTP is part of the model forward. A better design would be to add extra logic in py_executor that iterates through scheduled_batch and parses the sampling params before the model forward, or to do that iteration during the MTPWorker logit-sampling step.

This PR is here because the feature has been lacking for a long time and we have received many customer requests for it, so we are helping to merge this community-contributed PR into main. @laikhtewari

@tensorrt-cicd
Collaborator

PR_Github #12801 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9537 completed with status: 'FAILURE'

@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 74fa84e to 9c4da6b Compare July 25, 2025 01:31
@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #12919 [ run ] triggered by Bot

@@ -405,6 +405,7 @@ class MTPDecodingConfig(DecodingBaseConfig):
    relaxed_topk: int = 1
    relaxed_delta: float = 0.
    use_mtp_vanilla: bool = False
    use_advanced_mtp_sampler: Optional[bool] = False
Collaborator

Should the type be bool? Does None have a special meaning beyond False?

Collaborator Author

Thanks. It should be bool and default False, though the flag has been removed in the latest logic.

@tensorrt-cicd
Collaborator

PR_Github #12919 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9633 completed with status: 'FAILURE'

@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #13016 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #13016 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9721 completed with status: 'SUCCESS'

@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 9c4da6b to 0ee79b3 Compare July 29, 2025 21:08
@coderabbitai coderabbitai bot requested a review from Funatiq July 29, 2025 21:09
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

200-232: Top-k/top-p implementation is efficient and addresses past review concerns.

The implementation is mathematically correct and handles the complexity of combined top-k and top-p filtering. However, this directly relates to the past review comments about optimization when both k and p have default values.

From the past comments, I can see the author explained that device-side conditional execution isn't feasible with native torch ops and would break CUDA graph capture. The current approach of unconditional filtering is the correct design choice for maintaining performance and graph compatibility.

The sorting-based approach is necessary for top-p but can be expensive for large vocabularies. The current implementation is optimal given the constraints.

🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

154-167: Function implementation is correct but consider parameter validation.

The forward_native function correctly combines top-k/top-p filtering with random sampling. The in-place modification warning is helpful.

However, consider adding basic parameter validation to ensure robustness:

def forward_native(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    PyTorch-native implementation of top-k and top-p sampling.

    The logits tensor may be updated in-place.
    """
+   if logits.dim() != 2:
+       raise ValueError(f"Expected 2D logits tensor, got {logits.dim()}D")
    logits = apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9c4da6b and 0ee79b3.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/model_config.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (10 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (5 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • tensorrt_llm/_torch/model_config.py
  • tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/unittest/_torch/speculative/test_mtp.py
  • tensorrt_llm/_torch/speculative/mtp.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Learnt from: moraxu
PR: #6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Learnt from: yiqingy0
PR: #5198
File: jenkins/mergeWaiveList.py:0-0
Timestamp: 2025-07-22T08:33:49.109Z
Learning: In the TensorRT-LLM waive list merging system, removed lines are always located at the end of the merge waive lists, which is why the mergeWaiveList.py script uses reverse traversal - it's an optimization for this specific domain constraint.

Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

🔇 Additional comments (16)
tensorrt_llm/_torch/pyexecutor/sampler.py (5)

4-4: Import addition looks good.

The Optional import is correctly added and used throughout the new sampling functions.


169-178: Excellent CPU-GPU synchronization avoidance technique.

The random_sample function cleverly uses the Gumbel-max trick (exponential noise + argmax) to avoid torch.multinomial's CPU-GPU synchronization. This is a well-known technique for maintaining CUDA graph compatibility.

The implementation is mathematically sound and performance-optimized.
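
A minimal sketch of the technique (illustrative, not the exact code under review): dividing the probabilities by Exponential(1) noise and taking the argmax is equivalent to drawing from the categorical distribution, because the negative log of an Exponential(1) sample is Gumbel-distributed.

import torch

def random_sample_sketch(probs: torch.Tensor) -> torch.Tensor:
    # Gumbel-max via exponential noise: argmax(p_i / E_i) with E_i ~ Exp(1)
    # samples token i with probability p_i, with no CPU-GPU synchronization.
    noise = torch.empty_like(probs).exponential_(1.0)
    return probs.div(noise).argmax(dim=-1).view(-1)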


180-198: Min-p implementation is mathematically correct.

The adaptive probability thresholding implementation correctly:

  1. Converts logits to probabilities
  2. Finds maximum probability per sequence
  3. Scales min_p threshold by the maximum probability
  4. Masks tokens below the adaptive threshold

The logic aligns with the min-p sampling strategy used in modern language models.
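
A compact sketch of those four steps (illustrative, not the exact code under review):

import torch

def apply_min_p_sketch(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    # Scale each request's min_p threshold by its best token's probability.
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    # Mask out tokens whose probability falls below the adaptive threshold.
    return logits.masked_fill_(probs < scaled_min_p, float("-inf"))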


234-236: Greedy sampling implementation is correct and efficient.

Simple argmax implementation that correctly flattens the output to 1D as expected by the sampling interface.


238-244: Temperature scaling is efficient with in-place operation.

The in-place division correctly scales logits by temperature while avoiding unnecessary tensor allocation. This is optimal for performance.

tensorrt_llm/_torch/pyexecutor/model_engine.py (11)

20-20: LGTM!

The import of LlmRequest is necessary for accessing sampling configuration parameters in the new helper function.


282-283: LGTM!

The initialization logic correctly combines the three required conditions for enabling advanced MTP sampling: speculative decoding, MTP mode, and mixed sampler configuration.


301-301: LGTM!

The enable_mixed_sampler parameter is correctly passed to the model loading function, maintaining consistency with other configuration parameters.


1179-1223: LGTM!

The helper function is well-implemented with several good practices:

  • Temperature clamping to 0.01 prevents numerical instability
  • Uses torch.iinfo(torch.int32).max instead of magic numbers for disabled top_k
  • Proper handling of None/empty sampling config values
  • Clear separation of concerns with individual parameter extraction functions
  • Correct extension of parameters for draft tokens + main token

1173-1178: LGTM!

The sampling parameter lists are properly initialized when advanced MTP sampling is enabled.


1257-1258: LGTM!

Sampling parameters are correctly collected for context requests with no draft tokens.


1343-1344: LGTM!

Sampling parameters are correctly collected for extend requests, passing the appropriate draft token length.


1375-1376: LGTM!

Sampling parameters are correctly collected for requests with previous batch data, using the maximum draft length.


1410-1411: LGTM!

Sampling parameters are correctly collected for generation requests, using the maximum draft length for consistency.


1606-1608: LGTM!

The spec_metadata is correctly updated with the collected sampling parameters when advanced MTP sampling is enabled.


1159-1161: LGTM!

The advanced MTP sampling setup is correctly invoked during the forward pass with appropriate batch size and draft length parameters.

Comment on lines +246 to +264
def sampling_batch(logits: torch.Tensor, temperatures: torch.Tensor,
                   top_k: torch.Tensor, top_p: torch.Tensor,
                   min_p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    raw_probs = torch.softmax(logits, dim=-1)
    greedy_sampled = greedy_sample(logits)
    logits = apply_temperature(logits, temperatures)
    logits = apply_min_p(logits, min_p)
    random_sampled = forward_native(logits, top_k, top_p)
    next_tokens = torch.where(
        temperatures <= 1e-2,  # Match the clamping threshold
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,  # Reuse tensor
    )
    token_probs = torch.gather(raw_probs, dim=1,
                               index=next_tokens.unsqueeze(1)).squeeze(-1)
    log_probs = torch.log(token_probs)
    return next_tokens, log_probs

Contributor

🛠️ Refactor suggestion

Unified sampling function is well-designed but has a potential numerical stability issue.

The sampling_batch function effectively combines all sampling strategies with a temperature threshold for greedy vs. random sampling. The design is sound, but there's a potential numerical issue:

The log probability calculation at Line 262 could produce -inf values if token_probs contains zeros due to precision issues:

    token_probs = torch.gather(raw_probs, dim=1,
                               index=next_tokens.unsqueeze(1)).squeeze(-1)
-   log_probs = torch.log(token_probs)
+   log_probs = torch.log(token_probs.clamp(min=1e-8))

This prevents log(0) = -inf issues that could propagate through the system.


@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #13420 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #13420 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10039 completed with status: 'FAILURE'

@nvxuanyuc
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #13502 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #13502 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10115 completed with status: 'SUCCESS'

# Default to greedy mode. If true, use advanced pytorch sampling strategy.
self.enable_mixed_sampler = False
if self.model_config is not None:
    self.enable_mixed_sampler = self.model_config.enable_mixed_sampler
Collaborator

@ixlmar ixlmar Jul 30, 2025

Nitpick: This could be a @property rather than a copy, to avoid potential consistency issues in the future.
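
A minimal sketch of the suggested property, assuming the surrounding class keeps the model_config attribute shown in the snippet above:

class MixedSamplerFlagSketch:
    def __init__(self, model_config=None):
        self.model_config = model_config

    @property
    def enable_mixed_sampler(self) -> bool:
        # Read the flag on access instead of copying it in __init__,
        # so the value cannot drift from the configuration.
        return (self.model_config is not None
                and self.model_config.enable_mixed_sampler)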

Labels
Community want to contribute (PRs initiated from Community)

9 participants