[TRTLLM-7353][feat] Implement capturable drafting loops for speculation #7100
Conversation
📝 Walkthrough

Adds an optional drafting_loop_wrapper to PyTorchModelEngine and threads it through executor creation, introduces ChainDrafter and drafting_loops for capturable drafting, updates ModelDrafter to support a static drafting path when the draft engine is wrapped, and adjusts KV cache dummy-request logic to account for extra decoding steps.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Creator as py_executor_creator
    participant Engine as PyTorchModelEngine
    participant DraftModel as DraftModel
    participant Chain as ChainDrafter
    participant Drafter as ModelDrafter
    Creator->>Engine: create draft_model_engine(drafting_loop_wrapper=?)
    alt wrapper enabled
        Engine->>DraftModel: _load_model()
        Engine->>Engine: wrap with drafting_loop_wrapper -> ChainDrafter(DraftModel)
        Note right of Engine #DDEBF7: model_is_wrapped = True
    else no wrapper
        Engine->>DraftModel: _load_model()
        Note right of Engine #F7F7F7: model_is_wrapped = False
    end
    Creator->>Drafter: init(draft_model_engine)
    Drafter->>Drafter: use_static_draft_loop = model_is_wrapped
    alt static draft loop (wrapped)
        Drafter->>Chain: forward(input_ids, position_ids, attn_meta, spec_meta)
        Chain->>DraftModel: initial forward -> logits
        loop up to max_draft_len-1
            Chain->>DraftModel: step forward with last token
            Chain->>Chain: sample tokens (argmax ± d2t)
        end
        Chain-->>Drafter: stacked drafted tokens
        Drafter-->>Creator: prepared draft tokens (early return)
    else iterative (classic)
        Drafter->>DraftModel: forward (+ guided decoding/sampling loop)
        Drafter-->>Creator: prepared draft tokens
    end
```
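The wiring above can be sketched framework-free. This is only an illustration of the hook pattern the walkthrough describes, using toy stand-ins for the real engine and models; the names `drafting_loop_wrapper`, `ChainDrafter`-style chaining, and `model_is_wrapped` mirror the PR, while everything prefixed `Toy` is hypothetical.

```python
class ToyDraftModel:
    """Hypothetical stand-in for the real draft model: one greedy token per call."""

    def forward(self, token):
        return token + 1  # plays the role of argmax over logits


class ToyChainDrafter:
    """Chains max_draft_len single-token steps, as the capturable drafting loop does."""

    def __init__(self, max_draft_len, draft_model):
        self.max_draft_len = max_draft_len
        self.draft_model = draft_model

    def forward(self, last_token):
        drafted, token = [], last_token
        for _ in range(self.max_draft_len):
            token = self.draft_model.forward(token)
            drafted.append(token)
        return drafted


class ToyEngine:
    """Stand-in for PyTorchModelEngine: optionally wraps the loaded model."""

    def __init__(self, drafting_loop_wrapper=None):
        self.model = ToyDraftModel()  # _load_model()
        self.model_is_wrapped = drafting_loop_wrapper is not None
        if self.model_is_wrapped:
            self.model = drafting_loop_wrapper(self.model)


engine = ToyEngine(drafting_loop_wrapper=lambda m: ToyChainDrafter(3, m))
use_static_draft_loop = engine.model_is_wrapped  # what ModelDrafter consumes
print(use_static_draft_loop, engine.model.forward(10))  # True [11, 12, 13]
```

The point of the flag is visible here: the drafter never inspects the model, it only checks whether the engine applied a wrapper.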
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Actionable comments posted: 3
🧹 Nitpick comments (7)
tensorrt_llm/llmapi/llm_args.py (2)
403-405: Add type hints and a concise docstring to the new hook.

The default hook is public API surface; add typing and a short docstring for discoverability and consistency with the rest of the file’s style.
Apply this diff:
```diff
-    def get_draft_model_wrapper(self, model):
-        return None
+    def get_draft_model_wrapper(self, model: torch.nn.Module) -> Optional[torch.nn.Module]:
+        """Optionally wrap a loaded draft model before use by the PyTorch engine.
+
+        Args:
+            model: The instantiated draft model.
+
+        Returns:
+            A wrapper module to be used instead of the raw model, or None to use the model as-is.
+        """
+        return None
```
449-453: Validate inputs and document the Eagle-specific wrapper.

Two small hardening tweaks:
- Guard against an unset/invalid max_draft_len (defensive even if upstream validation usually sets it).
- Add a short docstring explaining the behavior.
Apply this diff:
```diff
-    def get_draft_model_wrapper(self, model):
-        from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter
-
-        return ChainDrafter(self.max_draft_len, model)
+    def get_draft_model_wrapper(self, model: torch.nn.Module) -> torch.nn.Module:
+        """Wrap the draft model with ChainDrafter for EAGLE(3) drafting."""
+        from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter
+
+        if not self.max_draft_len or self.max_draft_len <= 0:
+            raise ValueError("EagleDecodingConfig.max_draft_len must be > 0 when wrapping the draft model.")
+        return ChainDrafter(self.max_draft_len, model)
```

If there are call sites that can invoke this before BaseLlmArgs validators run, consider keeping the ValueError; otherwise we can rely on existing validation.
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
391-410: Unify runtime draft length handling and remove duplication.

Today, max_num_draft_tokens uses spec_config.max_draft_len while self.max_draft_len is set separately based on is_draft_model. Define a single runtime_draft_len and use it consistently; this reduces the risk of divergence and makes the intent explicit.
Also, the “TODO hack” makes without_logits depend on draft-model + Eagle3. That’s fine short-term, but please centralize this policy in spec_dec_mode if we keep it.
Apply this diff:
```diff
-        max_num_draft_tokens = self.spec_config.max_draft_len * batch_size if not self.is_draft_model else 0
+        runtime_draft_len = 0 if self.is_draft_model else self.spec_config.max_draft_len
+        max_num_draft_tokens = runtime_draft_len * batch_size
@@
-        # TODO undo this hack
-        self.without_logits = self.spec_config.spec_dec_mode.without_logits(
-        ) or (self.is_draft_model
-              and self.spec_config.spec_dec_mode.is_eagle3())
-        self.max_draft_len = spec_config.max_draft_len if not self.is_draft_model else 0
+        # TODO undo this hack: surface a formal API on spec_dec_mode for draft-model outputs.
+        self.without_logits = (
+            self.spec_config.spec_dec_mode.without_logits()
+            or (self.is_draft_model and self.spec_config.spec_dec_mode.is_eagle3())
+        )
+        self.max_draft_len = runtime_draft_len
```

Follow-up: _maybe_get_cuda_graph still uses spec_config.max_draft_len. Please verify we should switch it to self.runtime_draft_len for draft models to avoid capturing graphs with inconsistent draft_len. If you want, I can send a small targeted patch for that call site.
1155-1157: Wrapping after compile may skip compilation of the wrapper.

The draft model wrapper is applied after we potentially compile the model with torch.compile above. If the wrapper’s forward will participate in the graph, consider wrapping before compilation so the compiled artifact includes the wrapper path; otherwise document that the draft side runs uncompiled.
One option is to move the wrapper application right before the torch.compile call, gated on is_draft_model. If keeping it here, please confirm perf tradeoffs are acceptable for the draft engine.
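For illustration only, the ordering concern can be shown without torch: `fake_compile` below is a hypothetical marker standing in for `torch.compile`, and the `Draft`/`Wrapper` classes are toy stand-ins. Whatever object is handed to the compiler is what gets captured; a wrapper applied afterwards is invisible to it.

```python
def fake_compile(module):
    """Stand-in for torch.compile: tags the object it was handed."""
    module.compiled = True
    return module


class Draft:
    compiled = False

    def forward(self, x):
        return x + 1


class Wrapper:
    compiled = False

    def __init__(self, inner):
        self.inner = inner

    def forward(self, x):
        return self.inner.forward(x) * 2


# Current order: compile the raw model, then wrap. The wrapper's forward
# is never presented to the compiler.
wrap_after = Wrapper(fake_compile(Draft()))

# Suggested order: wrap first, then compile, so the compiled artifact
# includes the wrapper path.
wrap_before = fake_compile(Wrapper(Draft()))

print(wrap_after.compiled, wrap_before.compiled)  # False True
```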
tensorrt_llm/_torch/speculative/model_drafter.py (1)
261-263: Reconsider disabling d2t propagation.

Guided decoding consumers may rely on d2t offsets. If the intention is to keep it optional, gate it with a flag rather than commenting out. At minimum, keep outputs.get('d2t') populated when the model exposes it.
Example minimal change:
```diff
-        # if hasattr(self.draft_model_engine.model.model, 'd2t'):
-        #     outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
+        if hasattr(self.draft_model_engine.model, "model") and hasattr(self.draft_model_engine.model.model, "d2t"):
+            outputs['d2t'] = self.draft_model_engine.model.model.d2t
```

tensorrt_llm/_torch/speculative/eagle3.py (2)
498-505: Validate inputs and document ChainDrafter.

Add a brief class/method docstring and enforce max_draft_len > 0 defensively.
Apply this diff:
```diff
 class ChainDrafter(torch.nn.Module):

     def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
         super().__init__()
+        """A lightweight wrapper to generate chained draft tokens entirely on-device.
+
+        Args:
+            max_draft_len: Number of draft tokens to generate per step.
+            draft_model: Underlying draft model implementing forward() and exposing .model/.config.
+        """
+        if max_draft_len <= 0:
+            raise ValueError("ChainDrafter requires max_draft_len > 0.")
         self.draft_model = draft_model
         self.config = self.draft_model.config
         self.model_config = self.draft_model.model_config
         self.max_draft_len = max_draft_len
```
505-561: Be explicit about expected tensor shapes in forward and guard invariants.

Forward assumes draft_model.forward returns logits shaped per-seq and that argmax yields a 1D tensor of size batch_size. Add assertions to reduce surprises, and make the CUDA-graph save/restore section robust to partial batches.
Example guards:
```diff
@@
-        logits = self.draft_model.forward(input_ids=input_ids,
+        logits = self.draft_model.forward(input_ids=input_ids,
                                           position_ids=position_ids,
                                           attn_metadata=attn_metadata,
                                           spec_metadata=spec_metadata)
+        assert torch.is_tensor(logits), "ChainDrafter expects tensor logits from draft_model.forward"
@@
-        new_draft_tokens = [self.sample(logits)]
+        first_tokens = self.sample(logits)
+        assert first_tokens.dim() == 1 and first_tokens.numel() == batch_size, \
+            f"Expected 1D tokens of size batch_size={batch_size}, got shape {tuple(first_tokens.shape)}"
+        new_draft_tokens = [first_tokens]
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (3 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (0 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (2 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (3 hunks)
- tensorrt_llm/llmapi/llm_args.py (2 hunks)
💤 Files with no reviewable changes (1)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/llmapi/llm_args.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/speculative/eagle3.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/llmapi/llm_args.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/speculative/eagle3.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
🔇 Additional comments (1)
tensorrt_llm/_torch/speculative/eagle3.py (1)
2-2: LGTM: typing import expansion.

Import additions are fine and scoped. No action needed.
cdc0a2a to dc76df7
Actionable comments posted: 4
♻️ Duplicate comments (3)
tensorrt_llm/_torch/speculative/model_drafter.py (2)
75-76: Defaulting use_static_draft_loop to True flips global behavior and bypasses correctness checks. Make this opt-in.

This change alters the drafter’s control flow for all users and skips validation/guided-decoding/cleanup paths. It should be opt-in until parity and perf are verified.
Apply this minimal diff here:

```diff
-        self.use_static_draft_loop = True
+        self.use_static_draft_loop = False  # Opt-in until feature parity and perf are validated
```

And update the constructor signature (outside the changed range) to make this configurable:

```python
# Suggested change outside selected range (constructor signature)
def __init__(..., guided_decoder: Optional[GuidedDecoder] = None,
             use_static_draft_loop: bool = False):
    ...
    self.use_static_draft_loop = use_static_draft_loop
```

Run to find call sites to adjust (expect 0-2 in PyTorch backend only):

```bash
#!/bin/bash
rg -nP --type=py -C2 'ModelDrafter\('
```
383-392: Static path is incorrect: treats outputs as a Tensor, copies to CPU, ignores allocation limits, leaks resources.

- `outputs` is a dict; `outputs.cpu()` will raise.
- Host copy defeats the perf objective.
- Appends tokens for all requests (context + generation) and ignores `py_draft_pages_allocated`.
- Skips resource free, padding, and guided-decoder rollback.

Fix minimally or remove the fast-path until parity is achieved.

Apply this GPU-friendly fast-path that respects semantics:

```diff
-        if self.use_static_draft_loop:
-            outputs_host = outputs.cpu()
-            for token_idx in range(self.max_draft_tokens):
-                for req_idx, req in enumerate(draft_batch.all_requests()):
-                    target_req = req_id_to_old_request[req.py_request_id]
-                    target_req.py_draft_tokens.append(
-                        outputs_host[token_idx][req_idx])
-
-            return
+        if self.use_static_draft_loop:
+            # Expect Tensor of shape [max_draft_tokens, batch_size] or dict with 'next_draft_tokens'
+            tokens = outputs.get('next_draft_tokens', outputs)
+            assert isinstance(tokens, torch.Tensor), "Static path expects a Tensor or dict['next_draft_tokens']"
+            # Map only generation requests; contexts are prefills
+            ctx_count = len(draft_batch.context_requests)
+            gen_reqs = draft_batch.generation_requests
+            # [T, B] -> [B, T]
+            tokens_bt = tokens.transpose(0, 1).contiguous()
+            for i, req in enumerate(gen_reqs):
+                target_req = req_id_to_old_request[req.py_request_id]
+                # respect allocated pages
+                already = get_draft_token_length(target_req)
+                can_take = max(0, target_req.py_draft_pages_allocated - already)
+                n = min(self.max_draft_tokens, can_take)
+                if n > 0:
+                    target_req.py_draft_tokens.extend(tokens_bt[ctx_count + i, :n].tolist())
+            # Free draft seq slots for this pass
+            for req in draft_batch.all_requests():
+                self.draft_seq_slot_manager.free_resources(req)
+            # Pad to max_draft_tokens for CUDA graph invariants
+            self._pad_to_max_draft_tokens(scheduled_requests)
+            # Keep guided-decoding invariants
+            if self.guided_decoder is not None:
+                self.guided_decoder.rollback_draft_tokens(scheduled_requests)
+            return
```

Optional: if the wrapper returns a different shape, document it and adjust indexing here accordingly.
tensorrt_llm/_torch/speculative/eagle3.py (1)

564-569: Make d2t application robust and eliminate `.data` usage.

Guard attribute access, keep dtype integer, and avoid unsafe `.data`.

Apply:

```diff
-    def sample(self, logits: torch.Tensor) -> torch.Tensor:
-        tokens = torch.argmax(logits, dim=-1)
-        d2t = self.draft_model.model.d2t.data
-
-        return tokens + d2t[tokens]
+    def sample(self, logits: torch.Tensor) -> torch.Tensor:
+        tokens = torch.argmax(logits, dim=-1).to(torch.int32)
+        d2t = getattr(getattr(self.draft_model, "model", None), "d2t", None)
+        if d2t is not None:
+            # Ensure index dtype
+            idx = tokens.to(torch.long)
+            return (tokens + d2t[idx]).to(torch.int32)
+        return tokens
```
🧹 Nitpick comments (5)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
1-1: Missing NVIDIA copyright header.

Please prepend the 2025 NVIDIA header per repo guidelines.

Here’s a header snippet to add at file top (outside selected range):

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
```

tensorrt_llm/_torch/speculative/eagle3.py (2)
496-504: Add class/method docstrings and type hints to ChainDrafter for clarity and API stability.

Public class without docstrings makes usage unclear, especially for output shapes expected by ModelDrafter’s static path.

Apply:

```diff
-class ChainDrafter(torch.nn.Module):
+class ChainDrafter(torch.nn.Module):
+    """
+    Lightweight wrapper that chains up to `max_draft_len` single-token forwards on a draft model.
+    Returns a Tensor of shape [max_draft_len, batch_size] containing draft token ids (int32).
+    """

     def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
         super().__init__()
         self.draft_model = draft_model
         self.config = self.draft_model.config
         self.model_config = self.draft_model.model_config
         self.max_draft_len = max_draft_len
```
1-1: Missing NVIDIA copyright header.

Please prepend the 2025 NVIDIA header per repo guidelines.

Header snippet (outside selected range):

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
```

tensorrt_llm/llmapi/llm_args.py (2)
403-405: Add type hints and docstring to the base hook.

Clarify the expected model type and return semantics.

Apply:

```diff
-    def get_draft_model_wrapper(self, model):
-        return None
+    def get_draft_model_wrapper(self, model: torch.nn.Module) -> Optional[torch.nn.Module]:
+        """
+        Optionally wrap the loaded draft model with a lightweight module (e.g., a chained drafter).
+        Return None to use the draft model as-is.
+        """
+        return None
```
449-453: Only wrap for two-model EAGLE and document output shape.

When `eagle3_one_model` is True there is no separate draft model; returning a wrapper here can be confusing. Also, clarify the wrapper output to match the static path.

Apply:

```diff
-    def get_draft_model_wrapper(self, model):
-        from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter
-
-        return ChainDrafter(self.max_draft_len, model)
+    def get_draft_model_wrapper(self, model: torch.nn.Module) -> Optional[torch.nn.Module]:
+        """
+        Return a ChainDrafter wrapper for EAGLE3 two-model flow.
+        The wrapper must return next-draft tokens shaped [max_draft_len, batch_size] (int32).
+        """
+        if self.eagle3_one_model:
+            return None
+        from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter
+        return ChainDrafter(self.max_draft_len, model)
```

If the intent is to use ChainDrafter for one-model as well, please point to the exact call site so we can validate the control flow and shapes end-to-end.
🚧 Files skipped from review as they are similar to previous changes (1)
- tensorrt_llm/_torch/pyexecutor/model_engine.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/speculative/eagle3.py (1)
509-515: Confirm logits shape and preserve device; avoid implicit host transfers.

Ensure `draft_model.forward(...)` returns logits on CUDA and with batch semantics compatible with `torch.argmax(logits, dim=-1)` yielding [batch_size].

If the underlying model returns per-token logits with shape [N, V], clarify that N equals batch_size at each step. If not, adjust the `dim` argument or add a gather on the correct token indices. I can help wire a small unit test to lock this down.
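A shape-contract test along those lines could look like the sketch below. `argmax_last_dim` is a hypothetical pure-Python stand-in for `torch.argmax(logits, dim=-1)` so the test runs anywhere; the real test would exercise the actual model.

```python
def argmax_last_dim(logits):
    """Greedy token per row: [batch_size, vocab] -> list of batch_size ids."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def test_logits_shape_contract():
    batch_size, vocab = 3, 5
    logits = [[0.0] * vocab for _ in range(batch_size)]
    logits[0][2] = 1.0
    logits[1][4] = 1.0
    logits[2][0] = 1.0
    tokens = argmax_last_dim(logits)
    # argmax over the last dim must yield exactly one token id per sequence.
    assert len(tokens) == batch_size
    assert tokens == [2, 4, 0]


test_logits_shape_contract()
print("ok")
```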
dc76df7 to 882f98f
Actionable comments posted: 1
♻️ Duplicate comments (3)
tensorrt_llm/_torch/speculative/eagle3.py (3)

568-572: Unsafe `.data` access and missing guards in sample(); also ensure integer indexing dtype.

Directly accessing `.data` is unsafe; `d2t` may be absent or a Parameter/Tensor. Indexing requires `long`. Return dtype should stay consistent (int32 across the codebase). This can crash or silently mis-map tokens.

Apply:

```diff
-    def sample(self, logits: torch.Tensor) -> torch.Tensor:
-        tokens = torch.argmax(logits, dim=-1)
-        d2t = self.draft_model.model.d2t.data
-
-        return tokens + d2t[tokens]
+    def sample(self, logits: torch.Tensor) -> torch.Tensor:
+        # Greedy sample
+        tokens_long = torch.argmax(logits, dim=-1).to(torch.long)
+        # Optional d2t (draft-to-target) offset map
+        d2t = getattr(getattr(self.draft_model, "model", None), "d2t", None)
+        if d2t is not None:
+            # d2t may be a Tensor/Parameter; keep device consistent
+            d2t = d2t.to(device=tokens_long.device)
+            tokens_long = tokens_long + d2t[tokens_long]
+        return tokens_long.to(torch.int32)
```
518-537: Save/restore all mutated attn_metadata fields for CUDA-graph safety; guard optional fields.

Only seq_lens are saved/restored. You also mutate kv_lens_cuda, host_request_types, and num_contexts; not restoring these can poison subsequent runs and invalidate graph replays. Also, guard access when kv cache is disabled or fields are absent.

Apply:

```diff
         if attn_metadata.is_cuda_graph:
             seq_len = attn_metadata._seq_lens[:batch_size].clone()
             seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+            kv_lens_cuda_backup = (attn_metadata.kv_lens_cuda[:batch_size].clone()
+                                   if hasattr(attn_metadata, "kv_lens_cuda") else None)
+            num_contexts_backup = attn_metadata.num_contexts
+            host_req_types_backup = (attn_metadata.host_request_types[:attn_metadata.num_contexts].clone()
+                                     if hasattr(attn_metadata, "host_request_types") else None)
@@
-        attn_metadata.kv_lens_cuda[:batch_size] += 1
+        if hasattr(attn_metadata, "kv_lens_cuda"):
+            attn_metadata.kv_lens_cuda[:batch_size] += 1
@@
-        attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
-        attn_metadata.num_contexts = 0
+        if (getattr(attn_metadata, "kv_cache_manager", None) is not None and
+                hasattr(attn_metadata, "host_request_types")):
+            attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
+        attn_metadata.num_contexts = 0
@@
         if attn_metadata.is_cuda_graph:
             attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
             attn_metadata._seq_lens_cuda[:batch_size].copy_(
                 seq_len_cuda[:batch_size])
+            if kv_lens_cuda_backup is not None and hasattr(attn_metadata, "kv_lens_cuda"):
+                attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens_cuda_backup)
+            if host_req_types_backup is not None and hasattr(attn_metadata, "host_request_types"):
+                attn_metadata.host_request_types[:num_contexts_backup].copy_(host_req_types_backup)
+            attn_metadata.num_contexts = num_contexts_backup
+            attn_metadata.on_update()
```

Also applies to: 561-565
538-547: Hidden-state index mapping is incorrect for multi-request; must use per-slot start offsets.

Using 0..B-1 for write indices will clobber unrelated slots when multiple requests coexist. Use SlotManager’s start_indices keyed by request_ids.

Apply:

```diff
-        spec_metadata.eagle3_resource_manager.is_first_draft = False
-        spec_metadata.is_first_draft = False
-        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-            old_seqlens - 1)
-        spec_metadata.hidden_states_write_indices[:batch_size].copy_(
-            torch.arange(
-                batch_size,
-                dtype=spec_metadata.hidden_states_write_indices.dtype,
-                device=spec_metadata.hidden_states_write_indices.device))
-        spec_metadata.num_tokens = batch_size
+        rm = spec_metadata.eagle3_resource_manager
+        rm.is_first_draft = False
+        spec_metadata.is_first_draft = False
+        req_ids = spec_metadata.request_ids
+        assert req_ids is not None and len(req_ids) == batch_size
+        start_offsets = torch.tensor(
+            [rm.start_indices[rm.slot_manager.get_slot(int(rid))] for rid in req_ids],
+            device=old_seqlens.device, dtype=old_seqlens.dtype)
+        read_idx = start_offsets + (old_seqlens - 1)
+        write_idx = start_offsets + (old_seqlens - 1)
+        spec_metadata.hidden_states_read_indices[:batch_size].copy_(read_idx)
+        spec_metadata.hidden_states_write_indices[:batch_size].copy_(write_idx)
+        spec_metadata.num_tokens = batch_size
```
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/eagle3.py (1)
1-1: Missing NVIDIA copyright header.

Please prepend the standard NVIDIA copyright header for 2025 to comply with repo guidelines.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
405-410: Replace “hack” with an explicit spec-mode capability.

Hard-wiring `without_logits` for (is_draft_model && is_eagle3) is brittle. Push this into spec mode (e.g., `spec_dec_mode.draft_model_returns_tokens()`), so behavior is declared by the mode, not the engine.

Example:

```diff
-        # TODO undo this hack
-        self.without_logits = self.spec_config.spec_dec_mode.without_logits(
-        ) or (self.is_draft_model
-              and self.spec_config.spec_dec_mode.is_eagle3())
+        # Decide by spec mode contract
+        self.without_logits = self.spec_config.spec_dec_mode.without_logits() \
+            or (self.is_draft_model and getattr(self.spec_config.spec_dec_mode,
+                                                "draft_model_returns_tokens", lambda: False)())
```

Then implement `draft_model_returns_tokens()` for Eagle3 mode.
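The capability-method pattern could look like the following sketch. The mode classes here are hypothetical simplifications; the real `spec_dec_mode` objects differ, and `compute_without_logits` is an illustrative helper, not the engine's actual code.

```python
class SpecDecMode:
    """Base contract: by default a draft model returns logits, not tokens."""

    def without_logits(self):
        return False

    def draft_model_returns_tokens(self):
        return False


class Eagle3TwoModel(SpecDecMode):
    """EAGLE3 two-model flow: the wrapped draft model emits token ids directly."""

    def draft_model_returns_tokens(self):
        return True


def compute_without_logits(mode, is_draft_model):
    # The engine asks the mode rather than hard-coding is_eagle3().
    return mode.without_logits() or (is_draft_model
                                     and mode.draft_model_returns_tokens())


print(compute_without_logits(Eagle3TwoModel(), is_draft_model=True))   # True
print(compute_without_logits(Eagle3TwoModel(), is_draft_model=False))  # False
print(compute_without_logits(SpecDecMode(), is_draft_model=True))      # False
```

Adding a new speculative mode then means overriding one method instead of growing the engine's conditional.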
🚧 Files skipped from review as they are similar to previous changes (2)
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/llmapi/llm_args.py
🔇 Additional comments (3)
tensorrt_llm/_torch/speculative/eagle3.py (1)
566-566: Stack along batch-major dimension; verify return contract.

Most consumers expect next-draft tokens as [batch, draft_len]. Current `torch.stack(new_draft_tokens)` yields [draft_len, batch].

Apply:

```diff
-        return torch.stack(new_draft_tokens)
+        return torch.stack(new_draft_tokens, dim=1)
```

Follow-up: Confirm the expected return type for the draft wrapper. If callers expect a dict (e.g., {"next_draft_tokens": ...}), consider returning that instead to make the contract explicit. I can patch this once you confirm.
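The layout difference is easy to see with plain lists; in this toy example `zip(*...)` plays the role of stacking along dim=1.

```python
# Tokens from three draft steps over a batch of two sequences.
# Stacking per-step results gives the step-major layout [draft_len, batch]:
step_major = [[11, 21], [12, 22], [13, 23]]

# Consumers expect batch-major [batch, draft_len]; transposing fixes it,
# mirroring torch.stack(new_draft_tokens, dim=1).
batch_major = [list(col) for col in zip(*step_major)]

print(batch_major)  # [[11, 12, 13], [21, 22, 23]]
```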
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
391-399
: Zero-sized draft_tokens buffer for draft model: confirm all codepaths tolerate it.

For draft engines, `max_num_draft_tokens=0` allocates an empty CUDA Tensor. Most slices handle this, but please confirm capture/warmup and spec_metadata assignment never assume a non-empty buffer (especially with CUDA graphs or the overlap scheduler).
1158-1160
: Verify wrapper contract: inputs and outputs.

`get_draft_model_wrapper(model)` replaces the model with e.g. ChainDrafter. Ensure:
- the forward signature accepts the same kwargs used in `_forward_step` (it does, via `**kwargs`);
- the return value matches the engine’s `without_logits` path contract. If downstream expects `{"next_draft_tokens": tensor[B, L]}`, returning a bare Tensor will break it.

Recommend asserting or documenting the expected return; I can update ChainDrafter to return a dict if needed.
882f98f to cf7ce3e (Compare)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
994-1010
: Fix CUDA-graph draft_len mismatch for the drafter (uses full max_draft_len instead of 0).

After introducing max_draft_len=0 for draft engines, _maybe_get_cuda_graph still uses self.spec_config.max_draft_len when self.enable_spec_decode is True. For the draft engine, this will capture graphs with a nonzero draft_len, which does not match the drafter’s allocations (0) and can break graph selection/capture.

Use the existing runtime_draft_len property to reflect the new semantics.

Apply:

```diff
- draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+ draft_len = self.runtime_draft_len
```
♻️ Duplicate comments (4)
tensorrt_llm/_torch/speculative/eagle3.py (4)
518-533
: Restore all mutated attn_metadata fields for CUDA-graph correctness.

You back up _seq_lens/_seq_lens_cuda but also mutate kv_lens_cuda, num_contexts and host_request_types and never restore them. This can poison subsequent graph captures/replays.

Apply:

```diff
 if attn_metadata.is_cuda_graph:
     seq_len = attn_metadata._seq_lens[:batch_size].clone()
     seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+    kv_lens_cuda = getattr(attn_metadata, 'kv_lens_cuda', None)
+    kv_lens_cuda_bkp = kv_lens_cuda[:batch_size].clone() if kv_lens_cuda is not None else None
+    num_ctx_bkp = attn_metadata.num_contexts
+    host_req_types_bkp = None
+    if hasattr(attn_metadata, 'host_request_types'):
+        host_req_types_bkp = attn_metadata.host_request_types[:attn_metadata.num_contexts].clone()
@@
-    attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
-    attn_metadata.num_contexts = 0
+    if hasattr(attn_metadata, 'host_request_types'):
+        attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
+    attn_metadata.num_contexts = 0
@@
 if attn_metadata.is_cuda_graph:
     attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
     attn_metadata._seq_lens_cuda[:batch_size].copy_(
         seq_len_cuda[:batch_size])
+    if kv_lens_cuda_bkp is not None:
+        attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens_cuda_bkp)
+    attn_metadata.num_contexts = num_ctx_bkp
+    if host_req_types_bkp is not None:
+        attn_metadata.host_request_types[:attn_metadata.num_contexts].copy_(host_req_types_bkp)
```

Also applies to: 563-567
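The save-then-restore discipline asked for here can be factored into a small helper; a minimal sketch with a stand-in metadata object (`DummyMeta` and its fields are hypothetical, and real tensor fields would need `clone()`/`copy_()` rather than plain rebinding):

```python
from contextlib import contextmanager

@contextmanager
def restored_fields(obj, *names):
    """Save the named attributes, yield, then restore them even on error."""
    saved = {n: getattr(obj, n) for n in names if hasattr(obj, n)}
    try:
        yield
    finally:
        for n, v in saved.items():
            setattr(obj, n, v)

class DummyMeta:
    """Stand-in for attn_metadata; only scalar fields for illustration."""
    num_contexts = 2
    is_first_draft = True

meta = DummyMeta()
with restored_fields(meta, "num_contexts", "is_first_draft"):
    meta.num_contexts = 0       # mutate for the draft loop
    meta.is_first_draft = False
assert meta.num_contexts == 2 and meta.is_first_draft is True
```

The `finally` clause guarantees restoration even if the drafting loop raises, which is the failure mode that poisons later CUDA-graph replays.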
534-536
: Don’t forcibly flip is_first_draft; preserve and restore original flags.

Overriding these flags to False then True can break scheduling/warmup semantics and multi-iteration behavior. Save the originals and restore them.

Apply:

```diff
- spec_metadata.eagle3_resource_manager.is_first_draft = False
- spec_metadata.is_first_draft = False
+ prev_rm_is_first_draft = spec_metadata.eagle3_resource_manager.is_first_draft
+ prev_spec_is_first_draft = spec_metadata.is_first_draft
+ spec_metadata.eagle3_resource_manager.is_first_draft = False
+ spec_metadata.is_first_draft = False
@@
- spec_metadata.is_first_draft = True
- spec_metadata.eagle3_resource_manager.is_first_draft = True
+ spec_metadata.is_first_draft = prev_spec_is_first_draft
+ spec_metadata.eagle3_resource_manager.is_first_draft = prev_rm_is_first_draft
```

Also applies to: 560-561
539-547
: Hidden-state indices are wrong for multi-request; use per-slot start offsets.

Reading from old_write_indices[last_tokens_idx] and then resetting write indices to arange(B) ignores per-slot offsets and can clobber unrelated sequences. Compute indices via SlotManager.start_indices + per-request offsets.

Apply:

```diff
- old_write_indices = spec_metadata.hidden_states_write_indices
-
- spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-     old_write_indices[last_tokens_idx])
- spec_metadata.hidden_states_write_indices[:batch_size].copy_(
-     torch.arange(
-         batch_size,
-         dtype=spec_metadata.hidden_states_write_indices.dtype,
-         device=spec_metadata.hidden_states_write_indices.device))
- spec_metadata.num_tokens = batch_size
+ rm = spec_metadata.eagle3_resource_manager
+ req_ids = spec_metadata.request_ids
+ assert req_ids is not None and len(req_ids) == batch_size
+ # Per-request start offsets in the global hidden_states buffer
+ start_offsets = torch.tensor(
+     [rm.start_indices[rm.slot_manager.get_slot(int(rid))] for rid in req_ids],
+     device=last_tokens_idx.device, dtype=last_tokens_idx.dtype)
+ read_idx = start_offsets + last_tokens_idx
+ spec_metadata.hidden_states_read_indices[:batch_size].copy_(read_idx)
+ # For chained drafting we keep writing the last-token row per seq
+ spec_metadata.hidden_states_write_indices[:batch_size].copy_(read_idx)
+ spec_metadata.num_tokens = batch_size
@@
- if i == 0:
-     spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-         spec_metadata.hidden_states_write_indices[:batch_size])
+ # read/write indices already point to the per-seq last token; no change needed
```

Also applies to: 556-559
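The per-slot index arithmetic (`start_offsets + last_tokens_idx`) reduces to simple offset addition; a NumPy sketch with a hypothetical slot layout and sequence lengths:

```python
import numpy as np

# Hypothetical global hidden_states buffer layout: each slot owns a
# contiguous region of rows starting at start_indices[slot].
start_indices = {0: 0, 1: 128, 2: 256}   # slot -> first row of its region
request_slots = [2, 0, 1]                # slot assigned to each request
seq_lens = np.array([5, 3, 7])           # tokens per request this step

start_offsets = np.array([start_indices[s] for s in request_slots])
last_tokens_idx = seq_lens - 1           # last-token row within each region
read_idx = start_offsets + last_tokens_idx

# Request 0 lives in slot 2 (rows 256..), its last token is row 256 + 4.
assert read_idx.tolist() == [260, 2, 134]
```

Using plain `arange(batch_size)` instead would point every request at rows 0..B-1, i.e. into slot 0's region, which is exactly the clobbering the comment warns about.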
570-575
: Make sampling robust: avoid .data, handle missing d2t, ensure int dtype.

Using .data is unsafe; d2t may be missing; tokens should be integer-typed.

Apply:

```diff
- def sample(self, logits: torch.Tensor) -> torch.Tensor:
-     tokens = torch.argmax(logits, dim=-1)
-     d2t = self.draft_model.model.d2t.data
-
-     return tokens + d2t[tokens]
+ def sample(self, logits: torch.Tensor) -> torch.Tensor:
+     tokens = torch.argmax(logits, dim=-1).to(torch.int32)
+     d2t = getattr(getattr(self.draft_model, "model", None), "d2t", None)
+     if d2t is not None:
+         d2t = d2t.to(device=tokens.device)
+         return (tokens + d2t[tokens]).to(torch.int32)
+     return tokens
```
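The guarded greedy-sampling logic, stripped of engine state, looks like this; NumPy stands in for the torch ops, and the `d2t` offset table (draft-vocab id to target-vocab offset) is hypothetical:

```python
import numpy as np

def sample(logits, d2t=None):
    """Greedy-sample draft tokens; remap through d2t when available."""
    tokens = np.argmax(logits, axis=-1).astype(np.int32)
    if d2t is None:
        return tokens
    return (tokens + d2t[tokens]).astype(np.int32)

logits = np.array([[0.1, 0.9, 0.0],
                   [0.7, 0.2, 0.1]])
d2t = np.array([10, 20, 30], dtype=np.int32)  # hypothetical offset table

assert sample(logits).tolist() == [1, 0]         # plain argmax, no remapping
assert sample(logits, d2t).tolist() == [21, 10]  # 1 + d2t[1], 0 + d2t[0]
```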
🧹 Nitpick comments (5)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
391-395
: Allocate 0 draft-token slots for the drafter (good), but prefer self.batch_size for clarity.

The intent to avoid allocating draft-token storage in the drafter engine is correct. Minor: using self.batch_size keeps it consistent with the rest of the class and avoids confusion if the constructor arg and instance differ.

Apply:

```diff
- max_num_draft_tokens = self.spec_config.max_draft_len * batch_size if not self.is_draft_model else 0
+ max_num_draft_tokens = (
+     (self.spec_config.max_draft_len if not self.is_draft_model else 0)
+     * self.batch_size
+ )
```
405-410
: Codify the “without_logits” condition and avoid mixing spec_config vs self.spec_config.
- The Eagle3 drafter not returning logits is fine, but the “TODO undo this hack” suggests uncertainty. Please document this contract (“draft model returns token tensors; target model returns logits/dicts”) and make it stable.
- Prefer self.spec_config for consistency.
- max_draft_len: use self.spec_config for consistency with the rest of the file.
Apply:

```diff
- # TODO undo this hack
- self.without_logits = self.spec_config.spec_dec_mode.without_logits(
- ) or (self.is_draft_model
-       and self.spec_config.spec_dec_mode.is_eagle3())
- self.max_draft_len = spec_config.max_draft_len if not self.is_draft_model else 0
+ # Eagle3 two-model drafter does not return logits; it outputs draft tokens.
+ self.without_logits = (
+     self.spec_config.spec_dec_mode.without_logits()
+     or (self.is_draft_model and self.spec_config.spec_dec_mode.is_eagle3())
+ )
+ self.max_draft_len = (
+     self.spec_config.max_draft_len if not self.is_draft_model else 0
+ )
```
1-1
: Missing NVIDIA copyright header.

Per coding guidelines, prepend the current-year NVIDIA header to this source.
tensorrt_llm/_torch/speculative/eagle3.py (2)
498-507
: Add minimal class/method docstrings and types.

ChainDrafter is a new public wrapper. Please add short Google-style docstrings for the class and its methods (forward/sample) to clarify contracts (shapes, dtypes: returns [batch, max_draft_len] int32, no logits).
1-1
: Missing NVIDIA copyright header.

Per coding guidelines, prepend the current-year NVIDIA copyright header to this source.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
tensorrt_llm/_torch/pyexecutor/model_engine.py
(3 hunks)tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
(0 hunks)tensorrt_llm/_torch/speculative/eagle3.py
(2 hunks)tensorrt_llm/_torch/speculative/model_drafter.py
(3 hunks)tensorrt_llm/llmapi/llm_args.py
(2 hunks)
💤 Files with no reviewable changes (1)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/speculative/model_drafter.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/speculative/eagle3.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/speculative/eagle3.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
🔇 Additional comments (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
1158-1160
: Draft-model wrapper hook: good, but ensure wrapper is nn.Module and preserves expected API.

Wrapping the model via config is a clean integration point. Please add a guard/assert that the returned wrapper is an nn.Module and provides the same forward signature expected by this engine (especially for without_logits=True paths). Consider logging the wrapper type for debugging.

You can quickly sanity-check at runtime with:

```python
assert isinstance(model, torch.nn.Module), f"Draft wrapper must be nn.Module, got {type(model)}"
```

Would you like me to add this assert and a debug log?
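Beyond the isinstance check, the forward-signature expectation (that the wrapper swallows the engine's extra kwargs) can be verified with `inspect`; a sketch with a stand-in wrapper class, not the real ChainDrafter, which would also subclass `torch.nn.Module`:

```python
import inspect

class FakeWrapper:
    """Stand-in for a drafting-loop wrapper; illustrative only."""
    def forward(self, input_ids, position_ids, attn_metadata, spec_metadata,
                **kwargs):
        return None

def check_draft_wrapper(model):
    """Assert the wrapper exposes a forward() that accepts **kwargs."""
    fwd = getattr(model, "forward", None)
    assert callable(fwd), f"Draft wrapper must expose forward(), got {type(model)}"
    params = inspect.signature(fwd).parameters
    assert any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()), \
        "Draft wrapper forward() must accept **kwargs from _forward_step"

check_draft_wrapper(FakeWrapper())  # passes silently
```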
tensorrt_llm/_torch/speculative/eagle3.py (1)
2-2
: Typing import extension looks fine.

The added Any and Dict are used below. No issues.
cf7ce3e to 1a42cf9 (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (6)
tensorrt_llm/_torch/speculative/eagle3.py (6)
518-533
: Restore full attn_metadata state; add guards for optional fields; call on_update() after restore.

As written, only _seq_lens/_seq_lens_cuda are saved/restored. kv_lens_cuda, num_contexts, and host_request_types are mutated but never restored, which can poison subsequent invocations and CUDA-graph captures. Also, kv_lens_cuda/host_request_types may be absent depending on backend; unguarded access can raise.

Apply this diff:

```diff
@@
-if attn_metadata.is_cuda_graph:
-    seq_len = attn_metadata._seq_lens[:batch_size].clone()
-    seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+if attn_metadata.is_cuda_graph:
+    seq_len = attn_metadata._seq_lens[:batch_size].clone()
+    seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+    kv_lens_cuda_backup = (attn_metadata.kv_lens_cuda[:batch_size].clone()
+                           if hasattr(attn_metadata, "kv_lens_cuda") else None)
+    num_contexts_backup = attn_metadata.num_contexts
+    host_req_types_backup = (
+        attn_metadata.host_request_types[:attn_metadata.num_contexts].clone()
+        if hasattr(attn_metadata, "host_request_types") and attn_metadata.num_contexts > 0 else None
+    )
@@
-attn_metadata._seq_lens[:batch_size].fill_(1)
-attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
-attn_metadata.on_update()
-attn_metadata.kv_lens_cuda[:batch_size] += 1
+attn_metadata._seq_lens[:batch_size].fill_(1)
+attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+attn_metadata.on_update()
+if hasattr(attn_metadata, "kv_lens_cuda"):
+    attn_metadata.kv_lens_cuda[:batch_size] += 1
@@
-attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
+if hasattr(attn_metadata, "host_request_types") and attn_metadata.num_contexts > 0:
+    attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
 attn_metadata.num_contexts = 0
@@
-if attn_metadata.is_cuda_graph:
-    attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
-    attn_metadata._seq_lens_cuda[:batch_size].copy_(
-        seq_len_cuda[:batch_size])
+if attn_metadata.is_cuda_graph:
+    attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
+    attn_metadata._seq_lens_cuda[:batch_size].copy_(seq_len_cuda[:batch_size])
+    if 'kv_lens_cuda_backup' in locals() and kv_lens_cuda_backup is not None and hasattr(attn_metadata, "kv_lens_cuda"):
+        attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens_cuda_backup)
+    if 'num_contexts_backup' in locals():
+        attn_metadata.num_contexts = num_contexts_backup
+    if 'host_req_types_backup' in locals() and host_req_types_backup is not None and hasattr(attn_metadata, "host_request_types"):
+        attn_metadata.host_request_types[:attn_metadata.num_contexts].copy_(host_req_types_backup)
+    attn_metadata.on_update()
```

Also applies to: 563-567
534-536
: Don’t flip is_first_draft flags unconditionally; save and restore originals.

Unconditional False→True toggling can break outer scheduling/warmup semantics across iterations. Capture pre-call values and restore them.

```diff
- spec_metadata.eagle3_resource_manager.is_first_draft = False
- spec_metadata.is_first_draft = False
+ prev_rm_is_first_draft = spec_metadata.eagle3_resource_manager.is_first_draft
+ prev_spec_is_first_draft = spec_metadata.is_first_draft
+ spec_metadata.eagle3_resource_manager.is_first_draft = False
+ spec_metadata.is_first_draft = False
@@
- spec_metadata.is_first_draft = True
- spec_metadata.eagle3_resource_manager.is_first_draft = True
+ spec_metadata.is_first_draft = prev_spec_is_first_draft
+ spec_metadata.eagle3_resource_manager.is_first_draft = prev_rm_is_first_draft
```

Also applies to: 560-561
539-546
: Fix hidden-state index mapping for multi-seq buffers: use per-slot start_offsets.

Indexing with [0..B-1] (and relying on old_write_indices[last_tokens_idx]) ignores slot offsets and can clobber unrelated requests in the global hidden_states buffer.

```diff
- spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-     old_write_indices[last_tokens_idx])
- spec_metadata.hidden_states_write_indices[:batch_size].copy_(
-     torch.arange(
-         batch_size,
-         dtype=spec_metadata.hidden_states_write_indices.dtype,
-         device=spec_metadata.hidden_states_write_indices.device))
- spec_metadata.num_tokens = batch_size
+ rm = spec_metadata.eagle3_resource_manager
+ req_ids = spec_metadata.request_ids
+ assert req_ids is not None and len(req_ids) == batch_size
+ start_offsets = torch.tensor(
+     [rm.start_indices[rm.slot_manager.get_slot(int(rid))] for rid in req_ids],
+     dtype=spec_metadata.hidden_states_read_indices.dtype,
+     device=spec_metadata.hidden_states_read_indices.device
+ )
+ # Use old seqlens to address last-token rows per sequence.
+ old_seqlens = (seq_len_cuda[:batch_size] if attn_metadata.is_cuda_graph
+                else attn_metadata.seq_lens_cuda[:batch_size]).to(torch.long)
+ read_idx = start_offsets + (old_seqlens - 1)
+ spec_metadata.hidden_states_read_indices[:batch_size].copy_(read_idx)
+ spec_metadata.hidden_states_write_indices[:batch_size].copy_(read_idx)
+ spec_metadata.num_tokens = batch_size
```
548-556
: Guard kv_lens updates inside the draft loop.

kv_lens_cuda may not exist depending on backend. Avoid unconditional writes.

```diff
- attn_metadata.kv_lens_cuda[:batch_size] += 1
+ if hasattr(attn_metadata, "kv_lens_cuda"):
+     attn_metadata.kv_lens_cuda[:batch_size] += 1
```
568-569
: Return [batch, max_draft_len] not [max_draft_len, batch].

Current stacking yields transposed shape and breaks downstream consumers.

```diff
- return torch.stack(new_draft_tokens)
+ # [batch, max_draft_len]
+ return torch.stack(new_draft_tokens, dim=1)
```
570-575
: Make d2t application robust; avoid .data; ensure int32 indexing.

Accessing .data is unsafe; d2t may be absent; ensure integer dtype/device consistency.

```diff
- def sample(self, logits: torch.Tensor) -> torch.Tensor:
-     tokens = torch.argmax(logits, dim=-1)
-     d2t = self.draft_model.model.d2t.data
-
-     return tokens + d2t[tokens]
+ def sample(self, logits: torch.Tensor) -> torch.Tensor:
+     tokens = torch.argmax(logits, dim=-1).to(torch.int32)
+     d2t = getattr(getattr(self.draft_model, "model", None), "d2t", None)
+     if d2t is not None:
+         # d2t could be a Tensor/Parameter; no .data access
+         return (tokens + d2t[tokens]).to(torch.int32)
+     return tokens
```
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/eagle3.py (2)
498-506
: Add docstrings and type hints for public ChainDrafter API.

Improves readability and aligns with guidelines (Google-style docstrings; explicit types).

```diff
-class ChainDrafter(torch.nn.Module):
+class ChainDrafter(torch.nn.Module):
+    """Wrapper that generates a fixed-length chain of draft tokens for Eagle3 two-model mode.
+
+    This module delegates to the underlying draft_model for forward passes,
+    while managing attention/speculative metadata and token sampling.
+    """
@@
-    def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
+    def __init__(self, max_draft_len: int, draft_model: torch.nn.Module) -> None:
         super().__init__()
         self.draft_model = draft_model
         self.config = self.draft_model.config
         self.model_config = self.draft_model.model_config
         self.max_draft_len = max_draft_len
@@
-    def forward(self, input_ids, position_ids, attn_metadata, spec_metadata,
-                **kwargs):
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        attn_metadata: Any,
+        spec_metadata: Any,
+        **kwargs
+    ) -> torch.Tensor:
```

Also applies to: 507-515
1-1
: Add NVIDIA copyright header.

Per repository guidelines, prepend the current-year NVIDIA copyright header. Add this at the very top of the file:

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
```
📒 Files selected for processing (6)
tensorrt_llm/_torch/models/modeling_speculative.py
(0 hunks)tensorrt_llm/_torch/pyexecutor/model_engine.py
(7 hunks)tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
(0 hunks)tensorrt_llm/_torch/speculative/eagle3.py
(2 hunks)tensorrt_llm/_torch/speculative/model_drafter.py
(3 hunks)tensorrt_llm/llmapi/llm_args.py
(2 hunks)
💤 Files with no reviewable changes (2)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/_torch/models/modeling_speculative.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/speculative/model_drafter.py
🧰 Additional context used
🧬 Code graph analysis (2)
tensorrt_llm/_torch/speculative/eagle3.py (3)
- tensorrt_llm/_torch/models/modeling_utils.py (1): `config` (500-501)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (1): `load_weights_from_target_model` (2419-2428)
- tensorrt_llm/_torch/models/modeling_speculative.py (1): `load_weights_from_target_model` (286-291)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
- tensorrt_llm/_torch/speculative/interface.py (3): `without_logits` (53-54), `is_eagle3` (32-33), `extend_ctx` (87-100)
- tensorrt_llm/llmapi/llm_args.py (5): `spec_dec_mode` (386-393, 434-439, 533-538), `get_draft_model_wrapper` (395-396, 441-444)
🔇 Additional comments (6)
tensorrt_llm/_torch/pyexecutor/model_engine.py (6)
470-472
: runtime_draft_len gating correctly excludes draft engines.

This avoids double-counting draft tokens and simplifies overlap scheduler logic.
994-995
: CUDA graph draft_len gating looks correct.

Setting draft_len=0 for draft engines prevents capturing incompatible spec-dec paths for the drafter.
1028-1030
: Setting spec_metadata.max_draft_len=0 for draft engines during CUDA-graph capture is appropriate.

Prevents graph shape mismatches when running the drafter model through graphs.
391-395
: Confirm safe handling of zero-length `draft_tokens_cuda`.

Please verify that, when `is_draft_model=True`, all downstream slices and copy operations on the zero-length `draft_tokens_cuda` buffer either never execute or always operate on empty tensors. In particular, review these locations:
- Forward-pass copy from CPU to CUDA: `self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens, non_blocking=True)` (model_engine.py:1531)
- Forward-pass merge of previous-batch tokens: `self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens + previous_batch_draft_tokens].copy_(…)` (model_engine.py:1550)
- CUDA-graph metadata binding slices: `spec_metadata.draft_tokens = self.draft_tokens_cuda[:total_draft_lens]` (model_engine.py:1691 and 1834)

Ensure that in all of the above, either (a) the enclosing branch never runs with a non-empty buffer under `is_draft_model=True`, or (b) additional guards are added so no non-empty copy into an empty tensor is attempted.
1160-1162
: Draft-model weight tie-in occurs post-wrap.

Verified that `get_draft_model_wrapper` is applied in ModelEngine._load_model (model_engine.py:1160-1162) before calling `draft_model_engine.load_weights_from_target_model(model_engine.model)` in py_executor_creator.py:273, ensuring the wrapper's `load_weights_from_target_model` runs post-wrap. No action needed.
406-410
: without_logits override semantics verified.

- In `SpeculativeDecodingMode.without_logits()` (interface.py 53-55), only MTP and EAGLE3_ONE_MODEL return true, matching one-model flows.
- The added clause in `ModelEngine.__init__` (model_engine.py 406-408) applies only when `is_draft_model` and the mode is EAGLE3, so draft engines in the two-model flow skip logits while final engines still compute them.
- MTP/EAGLE hybrids (spec_dec_mode == MTP_EAGLE) remain unaffected and continue to return logits as before.

No changes needed.
1a42cf9 to d202608 (Compare)
Actionable comments posted: 3
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
1-1
: Missing NVIDIA copyright header (2025).

Per coding guidelines, prepend the current-year NVIDIA copyright header to all Python sources.
tensorrt_llm/llmapi/llm_args.py (1)
1-1
: Missing NVIDIA copyright header (2025).

Per coding guidelines, prepend the current-year NVIDIA copyright header to all Python sources.
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
1-1
: Missing NVIDIA copyright header (2025).

Per coding guidelines, prepend the current-year NVIDIA copyright header to all Python sources.
2119-2129
: Correct DoRA detection logic in `model_engine.py` [critical]

The only remaining pointer-based DoRA check is in `tensorrt_llm/_torch/pyexecutor/model_engine.py` (lines 2119–2122). Please replace this inverted heuristic with an explicit `is_dora` flag:

• File & Location: tensorrt_llm/_torch/pyexecutor/model_engine.py:2119–2129
• Proposed minimal patch:

```diff
- # BUG: pointer == 0 was used to detect DoRA (inverted/unreliable)
- is_dora = module.scaling_vec_pointer == 0
+ # Use an explicit flag for DoRA; default to False until we plumb it through configs
+ is_dora = bool(getattr(module, "is_dora", False))
  weights_in_pointer = module.weights_in_pointer
  weights_out_pointer = module.weights_out_pointer
  scaling_vec_pointer = module.scaling_vec_pointer
  if weights_in_pointer is None:
      weights_in_pointer = 0
  if weights_out_pointer is None:
      weights_out_pointer = 0
  if scaling_vec_pointer is None:
      scaling_vec_pointer = 0
```

• Next steps
- Ensure the C++/nanobind/pybind bindings for `TaskLayerModuleConfig` expose an `is_dora` boolean.
- Populate that field when building the module list in the engine (rather than inferring from a pointer).
- Verify that downstream logic (e.g., packing into `lora_params`) behaves correctly with the explicit flag.

This change eliminates the inverted, unreliable pointer-based detection and centralizes DoRA support on a proper boolean field.
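The `getattr(module, "is_dora", False)` pattern in the proposed patch degrades safely when the flag has not been plumbed through yet; a sketch with a stand-in module class, not the real `TaskLayerModuleConfig` binding:

```python
class LoraModule:
    """Stand-in LoRA module config; real bindings would expose is_dora explicitly."""
    def __init__(self, scaling_vec_pointer=None, is_dora=None):
        self.scaling_vec_pointer = scaling_vec_pointer
        if is_dora is not None:
            self.is_dora = is_dora

# Old inverted heuristic: pointer == 0 meant "DoRA", which misfires for
# plain LoRA modules that simply have no scaling vector.
legacy = LoraModule(scaling_vec_pointer=0)
assert legacy.scaling_vec_pointer == 0  # would be misread as DoRA

# Explicit flag: defaults to False when the attribute is absent,
# and reads True only when the bindings actually set it.
assert bool(getattr(legacy, "is_dora", False)) is False
assert bool(getattr(LoraModule(is_dora=True), "is_dora", False)) is True
```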
♻️ Duplicate comments (8)
tensorrt_llm/_torch/speculative/model_drafter.py (3)
75-76
: Defaulting use_static_draft_loop=True is risky; make it opt-in.

This flips global behavior and bypasses iterative/guided paths by default. Set False by default (and ideally expose a constructor arg).

Apply minimal change now:

```diff
- self.use_static_draft_loop = True
+ self.use_static_draft_loop = False  # Opt-in until parity/perf are validated
```

If desired, I can follow up with a patch to add a constructor parameter and docstring.
264-266
: Don’t silently drop d2t; re-attach under guard without .data.

Guided decoding and vocab remapping need d2t when present.

```diff
- # if hasattr(self.draft_model_engine.model.model, 'd2t'):
- #     outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
+ d2t = getattr(getattr(self.draft_model_engine.model, "model", None), "d2t", None)
+ if d2t is not None:
+     outputs['d2t'] = d2t
```
385-394
: Static path is functionally incorrect (TypeError on dict, CPU sync, ignores quotas).
- outputs is a dict from ModelEngine; calling .cpu() will crash.
- Moves data to CPU and appends indiscriminately.
- Ignores per-request draft allocation limits; doesn’t free seq slots; misses padding and guided rollback.
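The quota clamp a correct static-path mapping needs (draft slots allocated minus tokens already present) reduces to a few lines; a sketch with stand-in request objects, where `len()` stands in for `get_draft_token_length`:

```python
class FakeRequest:
    """Stand-in for a target LlmRequest; only the fields the mapping touches."""
    def __init__(self, allocated, existing):
        self.py_draft_pages_allocated = allocated
        self.py_draft_tokens = list(existing)

def append_draft_tokens(req, new_tokens, max_draft_tokens):
    """Extend py_draft_tokens, never exceeding the per-request allocation."""
    already = len(req.py_draft_tokens)
    slots_remaining = max(0, req.py_draft_pages_allocated - already)
    n = min(max_draft_tokens, slots_remaining)
    if n > 0:
        req.py_draft_tokens.extend(new_tokens[:n])

req = FakeRequest(allocated=4, existing=[7, 8])   # room for 2 more tokens
append_draft_tokens(req, [1, 2, 3], max_draft_tokens=3)
assert req.py_draft_tokens == [7, 8, 1, 2]        # clamped to the quota
```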
Replace the block with a GPU-friendly, quota-respecting mapping:

```diff
- if self.use_static_draft_loop:
-     outputs_host = outputs.cpu()
-     for token_idx in range(self.max_draft_tokens):
-         for req_idx, req in enumerate(draft_batch.all_requests()):
-             target_req = req_id_to_old_request[req.py_request_id]
-             target_req.py_draft_tokens.append(
-                 outputs_host[token_idx][req_idx])
-
-     return
+ if self.use_static_draft_loop:
+     # Accept either Tensor or dict with 'next_draft_tokens'
+     tokens = outputs
+     if isinstance(tokens, dict):
+         tokens = tokens.get('next_draft_tokens', None)
+     assert isinstance(tokens, torch.Tensor), "Expected draft tokens Tensor or outputs['next_draft_tokens']"
+     # Normalize to [batch, max_draft_tokens]
+     if tokens.dim() == 2 and tokens.shape[0] == self.max_draft_tokens:
+         tokens_bt = tokens.transpose(0, 1)  # [B, T]
+     else:
+         tokens_bt = tokens  # assume [B, T]
+     # Map only generation requests; contexts come first
+     ctx_count = len(draft_batch.context_requests)
+     for i, req in enumerate(draft_batch.generation_requests):
+         target_req = req_id_to_old_request[req.py_request_id]
+         already = get_draft_token_length(target_req)
+         slots_remaining = max(0, target_req.py_draft_pages_allocated - already)
+         n = min(self.max_draft_tokens, slots_remaining)
+         if n > 0:
+             target_req.py_draft_tokens.extend(tokens_bt[ctx_count + i, :n].tolist())
+     # Free resources and keep downstream assumptions
+     for req in draft_batch.all_requests():
+         self.draft_seq_slot_manager.free_resources(req)
+     self._pad_to_max_draft_tokens(scheduled_requests)
+     if self.guided_decoder is not None:
+         self.guided_decoder.rollback_draft_tokens(scheduled_requests)
+     return
```

tensorrt_llm/_torch/speculative/eagle3.py (5)
548-556
: Guard kv_lens_cuda increments in the loop.

Avoid AttributeError where kv_lens_cuda is not present.

```diff
- attn_metadata.kv_lens_cuda[:batch_size] += 1
+ if hasattr(attn_metadata, 'kv_lens_cuda'):
+     attn_metadata.kv_lens_cuda[:batch_size] += 1
```
526-533
: Guard metadata mutations for optional fields. kv_lens_cuda/host_request_types may be absent; unconditional writes can raise on certain backends.

```diff
-        attn_metadata._seq_lens[:batch_size].fill_(1)
-        attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
-        attn_metadata.on_update()
-        attn_metadata.kv_lens_cuda[:batch_size] += 1
+        attn_metadata._seq_lens[:batch_size].fill_(1)
+        attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+        attn_metadata.on_update()
+        if hasattr(attn_metadata, 'kv_lens_cuda'):
+            attn_metadata.kv_lens_cuda[:batch_size] += 1
-
-        attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
-        attn_metadata.num_contexts = 0
+        if hasattr(attn_metadata, 'host_request_types'):
+            attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
+        attn_metadata.num_contexts = 0
```
534-536
: Preserve and restore is_first_draft flags instead of toggling blindly. Unconditional False/True can break outer scheduling/warmup semantics.

```diff
-        spec_metadata.eagle3_resource_manager.is_first_draft = False
-        spec_metadata.is_first_draft = False
+        prev_rm_is_first_draft = spec_metadata.eagle3_resource_manager.is_first_draft
+        prev_spec_is_first_draft = spec_metadata.is_first_draft
+        spec_metadata.eagle3_resource_manager.is_first_draft = False
+        spec_metadata.is_first_draft = False
@@
-        spec_metadata.is_first_draft = True
-        spec_metadata.eagle3_resource_manager.is_first_draft = True
+        spec_metadata.is_first_draft = prev_spec_is_first_draft
+        spec_metadata.eagle3_resource_manager.is_first_draft = prev_rm_is_first_draft
```

Also applies to: 560-561
568-568
: Output shape bug: return [batch, max_draft_len], not [max_draft_len, batch]. Consumers expect [B, T].

```diff
-        return torch.stack(new_draft_tokens)
+        return torch.stack(new_draft_tokens, dim=1)
```
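For reviewers less familiar with the stacking semantics: stacking T per-step token lists along dim=0 yields a [T, B] tensor, while dim=1 yields [B, T]. The same transpose can be sketched without torch (helper name is made up for illustration):

```python
def stack_steps_as_batch_major(step_tokens):
    """Convert T lists of per-request tokens (layout [T, B]) into
    per-request lists of drafted tokens (layout [B, T]) -- the
    layout consumers of the draft loop expect."""
    # zip(*...) pairs up element i of every step, i.e. a transpose.
    return [list(per_request) for per_request in zip(*step_tokens)]


# Three draft steps, batch of two requests: input layout is [T=3, B=2].
steps = [[11, 21], [12, 22], [13, 23]]
print(stack_steps_as_batch_major(steps))  # [[11, 12, 13], [21, 22, 23]]
```

This mirrors what `torch.stack(new_draft_tokens, dim=1)` does for a list of per-step 1-D tensors.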
571-574
: Avoid `.data` and make d2t application robust. Accessing `.data` is unsafe; also ensure integer dtype and guard missing d2t.

```diff
-    def sample(self, logits: torch.Tensor) -> torch.Tensor:
-        tokens = torch.argmax(logits, dim=-1)
-        d2t = self.draft_model.model.d2t.data
-
-        return tokens + d2t[tokens]
+    def sample(self, logits: torch.Tensor) -> torch.Tensor:
+        tokens = torch.argmax(logits, dim=-1).to(torch.int32)
+        d2t = getattr(getattr(self.draft_model, "model", None), "d2t", None)
+        if d2t is not None:
+            # d2t is an offset table; index with long and keep int32 tokens
+            return (tokens + d2t[tokens.to(torch.long)]).to(torch.int32)
+        return tokens
```
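The d2t table maps a draft-vocabulary token id to an offset into the target vocabulary, and the suggested fix applies it as `tokens + d2t[tokens]`. A minimal pure-Python sketch of that remapping (the vocabulary sizes and table values here are invented for illustration):

```python
def argmax(row):
    # Index of the largest logit in one row.
    return max(range(len(row)), key=row.__getitem__)


def sample_with_d2t(logits, d2t=None):
    """Greedy-sample one draft token per row, then shift each id
    into the target vocabulary via the optional d2t offset table."""
    tokens = [argmax(row) for row in logits]
    if d2t is None:
        # No table: draft and target vocabularies coincide.
        return tokens
    return [t + d2t[t] for t in tokens]


# Draft vocab of size 3; d2t says draft ids 0, 1, 2 map to target ids 0, 5, 9.
d2t = [0, 4, 7]
logits = [[0.1, 0.7, 0.2], [0.9, 0.0, 0.1]]
print(sample_with_d2t(logits, d2t))  # [5, 0]
```

The guarded `d2t is None` branch corresponds to the `getattr(..., "d2t", None)` check in the diff above.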
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
242-244
: CUDA graph gating may be too permissive when the static loop is on. Unconditionally returning False when use_static_draft_loop is True assumes fully static shapes. If the wrapper output shape varies (e.g., batch composition), this can cause invalid graph replays.
Consider guarding by both the static-loop flag and a quick shape check (e.g., cache last shapes and only enable graphs if unchanged). Want a patch?
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
406-410
: Re-check without_logits gating for 2-model EAGLE3. You're forcing without_logits for draft engines under EAGLE3 even if the draft wrapper or downstream expects logits. Confirm no paths rely on logits from the draft engine (e.g., mixed sampler, debugging).
If needed, key off a capability bit on the wrapper instead of is_eagle3().
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
- tensorrt_llm/_torch/models/modeling_speculative.py (0 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (7 hunks)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (0 hunks)
- tensorrt_llm/_torch/speculative/eagle3.py (2 hunks)
- tensorrt_llm/_torch/speculative/model_drafter.py (4 hunks)
- tensorrt_llm/llmapi/llm_args.py (2 hunks)
💤 Files with no reviewable changes (2)
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/_torch/models/modeling_speculative.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/llmapi/llm_args.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/llmapi/llm_args.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/speculative/eagle3.py (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
- forward (77-85), forward (2230-2321), load_weights_from_target_model (2419-2428)

tensorrt_llm/_torch/attention_backend/interface.py (6)
- num_seqs (252-256), seq_lens_cuda (222-223), on_update (151-161), num_contexts (202-203), num_contexts (206-209), num_tokens (274-275)
tensorrt_llm/_torch/speculative/model_drafter.py (3)
tensorrt_llm/runtime/generation.py (1)
max_draft_tokens
(1283-1286)tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
all_requests
(38-39)tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
append
(79-98)append
(125-142)
tensorrt_llm/llmapi/llm_args.py (1)
tensorrt_llm/_torch/speculative/eagle3.py (1)
ChainDrafter
(498-578)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
tensorrt_llm/_torch/speculative/interface.py (3)
without_logits
(53-54)is_eagle3
(32-33)extend_ctx
(87-100)tensorrt_llm/llmapi/llm_args.py (2)
get_draft_model_wrapper
(403-404)get_draft_model_wrapper
(449-452)
🔇 Additional comments (4)
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)
391-395
: LGTM: Zero draft-token buffer for draft engines. Allocating 0-sized draft_tokens_cuda for is_draft_model avoids wasted memory and clarifies the roles of the two engines.
471-471
: LGTM: runtime_draft_len excludes draft engines. This prevents accidental draft allocation on the draft-engine execution path.
995-1004
: LGTM: CUDA graph keying excludes draft engines. Deriving draft_len=0 for draft engines keeps capture variants well-partitioned.
1160-1163
: Wrapper application: verify attributes and forward contract. ChainDrafter must expose config/model_config and a forward signature compatible with callers. The current wrapper sets those, but please confirm no downstream path accesses attributes only on the raw model.
I can script a repo-wide check to find attribute accesses on self.model in the draft path if helpful.
d202608 to 0874f2b — Compare
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
393-412
: Critical: Ensure the 'logits' key is always present for the static draft loop.

`ModelDrafter` always sets `use_static_draft_loop = True` (line 75), so the non-static branch never executes. However, even in the static path the code still does:

```python
self._execute_guided_decoder(draft_batch, outputs['logits'], d2t=outputs.get('d2t'))
```

If `spec_dec_mode.without_logits()` is set, or the EAGLE3 draft engine (via `ChainDrafter`) returns a raw tensor without logits, this lookup will KeyError at runtime. You must choose one strategy:

- Restore `logits` in the dict returned by `ChainDrafter.forward` (e.g. wrap its raw output under `'logits'`) whenever `without_logits=True`.
- OR update `ModelDrafter._forward_draft_model` to detect a bare tensor return and wrap it into `{'logits': tensor}` before entering the static loop.
- OR adjust `spec_dec_mode.without_logits` so that we never enter the draft engine with `without_logits=True` when `use_static_draft_loop` is active.

Please pick an approach and patch accordingly. If you'd like help with the diff on either side, let me know.
♻️ Duplicate comments (10)
tensorrt_llm/_torch/speculative/model_drafter.py (3)
75-76
: Defaulting use_static_draft_loop=True is risky; make it opt-in (False). This flips global behavior and bypasses validation/guided-decoding/cleanup paths by default. Make the static loop opt-in until parity is locked.

```diff
-        self.use_static_draft_loop = True
+        self.use_static_draft_loop = False  # Opt-in until feature parity/perf are validated
```
263-266
: Safely re-attach d2t mapping when available; avoid `.data`. The d2t propagation is commented out. Guided decoding and remapping paths depend on it when present.

```diff
-        # if hasattr(self.draft_model_engine.model.model, 'd2t'):
-        #     outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
+        if isinstance(outputs, dict):
+            d2t = getattr(getattr(self.draft_model_engine.model, "model", None), "d2t", None)
+            if d2t is not None:
+                outputs['d2t'] = d2t  # do not use .data
```
385-394
: Static draft path is functionally incorrect (CPU copy, wrong mapping, no cleanup). Fix the fast-path. Current code:
- Moves outputs to CPU (performance regression).
- Treats outputs as [T, B] blindly; ignores contexts and per-request allocations.
- Leaks resources (no free) and skips padding/guided rollback.
Apply a GPU-friendly, correct fast-path that respects shape and invariants.
```diff
-        if self.use_static_draft_loop:
-            outputs_host = outputs.cpu()
-            for token_idx in range(self.max_draft_tokens):
-                for req_idx, req in enumerate(draft_batch.all_requests()):
-                    target_req = req_id_to_old_request[req.py_request_id]
-                    target_req.py_draft_tokens.append(
-                        outputs_host[token_idx][req_idx])
-
-            return
+        if self.use_static_draft_loop:
+            # Accept either Tensor [B, T] or [T, B], or dict with 'next_draft_tokens'
+            tokens = outputs
+            if isinstance(tokens, dict):
+                tokens = tokens.get('next_draft_tokens', None)
+            assert isinstance(tokens, torch.Tensor), "Expected Tensor or dict['next_draft_tokens']"
+            # Normalize to [B, T]
+            if tokens.dim() == 2 and tokens.shape[0] == draft_batch.batch_size:
+                tokens_bt = tokens
+            else:
+                tokens_bt = tokens.transpose(0, 1)
+
+            ctx_count = len(draft_batch.context_requests)
+            gen_reqs = draft_batch.generation_requests
+            for i, req in enumerate(gen_reqs):
+                target_req = req_id_to_old_request[req.py_request_id]
+                # Respect allocated pages
+                current = get_draft_token_length(target_req)
+                budget = max(0, target_req.py_draft_pages_allocated - current)
+                n = min(self.max_draft_tokens, budget)
+                if n > 0:
+                    target_req.py_draft_tokens.extend(tokens_bt[ctx_count + i, :n].tolist())
+
+            # Free seq slots allocated for this draft pass
+            for req in draft_batch.all_requests():
+                self.draft_seq_slot_manager.free_resources(req)
+            # Ensure padding to max_draft_tokens for CUDA graph safety
+            self._pad_to_max_draft_tokens(scheduled_requests)
+            # Keep guided decoding invariants if present
+            if self.guided_decoder is not None:
+                self.guided_decoder.rollback_draft_tokens(scheduled_requests)
+            return
```

tensorrt_llm/_torch/speculative/eagle3.py (7)
1-1
: Add NVIDIA Apache-2.0 copyright header (2025). This production source is missing the required header.

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
558-566
: Guard kv_lens updates in the draft loop. kv_lens_cuda may be absent depending on the backend; avoid AttributeError.

```diff
-        attn_metadata.kv_lens_cuda[:batch_size] += 1
+        if hasattr(attn_metadata, 'kv_lens_cuda'):
+            attn_metadata.kv_lens_cuda[:batch_size] += 1
```
528-543
: Guard attn_metadata mutations and restore state; call on_update after restoration.
- kv_lens_cuda and host_request_types may be absent; guard accesses.
- After restoring seq_lens, also call attn_metadata.on_update() for CUDA graph safety.
```diff
     if attn_metadata.is_cuda_graph:
         seq_len = attn_metadata._seq_lens[:batch_size].clone()
         seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
@@
-        attn_metadata._seq_lens[:batch_size].fill_(1)
-        attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
-        attn_metadata.on_update()
-        attn_metadata.kv_lens_cuda[:batch_size] += 1
+        attn_metadata._seq_lens[:batch_size].fill_(1)
+        attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+        attn_metadata.on_update()
+        if hasattr(attn_metadata, 'kv_lens_cuda'):
+            attn_metadata.kv_lens_cuda[:batch_size] += 1
@@
-        attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
-        attn_metadata.num_contexts = 0
+        if hasattr(attn_metadata, 'host_request_types'):
+            attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
+        attn_metadata.num_contexts = 0
@@
-        if attn_metadata.is_cuda_graph:
-            attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
-            attn_metadata._seq_lens_cuda[:batch_size].copy_(
-                seq_len_cuda[:batch_size])
+        if attn_metadata.is_cuda_graph:
+            attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
+            attn_metadata._seq_lens_cuda[:batch_size].copy_(seq_len_cuda[:batch_size])
+            attn_metadata.on_update()
```

Also applies to: 573-577
544-546
: Preserve and restore original is_first_draft flags; don't force-flip. Avoid breaking outer scheduling semantics by toggling without restoring.

```diff
-        spec_metadata.eagle3_resource_manager.is_first_draft = False
-        spec_metadata.is_first_draft = False
+        prev_rm_is_first_draft = spec_metadata.eagle3_resource_manager.is_first_draft
+        prev_spec_is_first_draft = spec_metadata.is_first_draft
+        spec_metadata.eagle3_resource_manager.is_first_draft = False
+        spec_metadata.is_first_draft = False
@@
-        spec_metadata.is_first_draft = True
-        spec_metadata.eagle3_resource_manager.is_first_draft = True
+        spec_metadata.is_first_draft = prev_spec_is_first_draft
+        spec_metadata.eagle3_resource_manager.is_first_draft = prev_rm_is_first_draft
```

Also applies to: 570-571
549-556
: Hidden-state index mapping is wrong for multi-seq; use per-slot start offsets. Writing indices as [0..B-1] ignores per-slot offsets and can clobber unrelated sequences.

```diff
-        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-            old_write_indices[last_tokens_idx])
-        spec_metadata.hidden_states_write_indices[:batch_size].copy_(
-            torch.arange(
-                batch_size,
-                dtype=spec_metadata.hidden_states_write_indices.dtype,
-                device=spec_metadata.hidden_states_write_indices.device))
-        spec_metadata.num_tokens = batch_size
+        rm = spec_metadata.eagle3_resource_manager
+        req_ids = spec_metadata.request_ids
+        assert req_ids is not None and len(req_ids) == batch_size
+        start_offsets = torch.tensor(
+            [rm.start_indices[rm.slot_manager.get_slot(int(rid))] for rid in req_ids],
+            device=old_write_indices.device,
+            dtype=old_write_indices.dtype)
+        read_idx = start_offsets + (attn_metadata.seq_lens_cuda[:batch_size] - 1)
+        write_idx = read_idx  # write last token row per seq
+        spec_metadata.hidden_states_read_indices[:batch_size].copy_(read_idx)
+        spec_metadata.hidden_states_write_indices[:batch_size].copy_(write_idx)
+        spec_metadata.num_tokens = batch_size
@@
-        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-            spec_metadata.hidden_states_write_indices[:batch_size])
+        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+            spec_metadata.hidden_states_write_indices[:batch_size])
```

Also applies to: 567-569
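The indexing fix boils down to one invariant: the flat hidden-state row of a sequence's last token is its slot's start offset plus (seq_len - 1), not its position in the batch. A torch-free sketch (the slot layout is invented for illustration):

```python
def last_token_rows(start_offsets, seq_lens):
    """Flat row index of the last token of each sequence, in a
    buffer where sequence i owns the rows starting at
    start_offsets[i]."""
    return [start + (length - 1)
            for start, length in zip(start_offsets, seq_lens)]


# Two sequences: slots start at rows 0 and 8, current lengths 3 and 5.
# Batch-position indexing would wrongly give rows [0, 1].
print(last_token_rows([0, 8], [3, 5]))  # [2, 12]
```

With batch-position indexing the second sequence would read/write row 1, inside the first sequence's slot, which is exactly the clobbering the review describes.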
578-578
: Fix output shape to [batch, max_draft_len]. Stack along dim=1 to match downstream consumers (e.g., the overlap scheduler).

```diff
-        return torch.stack(new_draft_tokens)
+        # [batch, max_draft_len]
+        return torch.stack(new_draft_tokens, dim=1)
```
580-585
: Make d2t application robust; avoid `.data` and ensure integer dtype. Use guarded access and dtype/device alignment.

```diff
-    def sample(self, logits: torch.Tensor) -> torch.Tensor:
-        tokens = torch.argmax(logits, dim=-1)
-        d2t = self.draft_model.model.d2t.data
-
-        return tokens + d2t[tokens]
+    def sample(self, logits: torch.Tensor) -> torch.Tensor:
+        tokens = torch.argmax(logits, dim=-1).to(torch.int32)
+        d2t = getattr(getattr(self.draft_model, "model", None), "d2t", None)
+        if d2t is not None:
+            d2t = d2t.to(device=tokens.device)
+            return (tokens + d2t[tokens]).to(torch.int32)
+        return tokens
```
🧹 Nitpick comments (5)
tensorrt_llm/_torch/speculative/model_drafter.py (2)
1-1
: Add NVIDIA Apache-2.0 copyright header (2025).Per the repository guidelines, prepend the standard NVIDIA Apache-2.0 header to all Python sources.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
242-244
: Clarify CUDA graph gating for the static path. Returning False here disables "disable CUDA graph" (i.e., allows capture) whenever the static loop is on. That's fine only if the static loop guarantees fixed shapes and padding. Please confirm this with tests or gate it behind the same flag.
Would you like me to wire a minimal unit/env test that toggles this flag and asserts graph capture reuse across iterations?
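One way to implement the suggested guard: cache the shape the graph was captured with and permit replay only while the incoming batch shape matches. A minimal sketch (class and method names are hypothetical, not the engine's real API):

```python
class GraphShapeGuard:
    """Allow CUDA-graph replay only while the batch shape is
    identical to the one the graph was captured with."""

    def __init__(self):
        self._captured_shape = None

    def can_replay(self, batch_shape):
        if self._captured_shape is None:
            # First call: record the capture shape; no replay yet.
            self._captured_shape = batch_shape
            return False
        return batch_shape == self._captured_shape


guard = GraphShapeGuard()
print(guard.can_replay((4, 8)))  # False (capture on first use)
print(guard.can_replay((4, 8)))  # True  (same shape, replay ok)
print(guard.can_replay((2, 8)))  # False (shape changed)
```

In the real engine this check would sit next to the `use_static_draft_loop` gate, falling back to eager execution (or re-capture) whenever the shape diverges.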
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
1-1
: Add NVIDIA Apache-2.0 copyright header (2025).Please add the standard header to comply with repo policy.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)
1-1
: Add NVIDIA Apache-2.0 copyright header (2025).Production source; header is required.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
259-263
: wrap_model: ensure the draft output shape contract is documented as [B, T]. ChainDrafter currently stacks drafted tokens along the step (time) dimension; downstream consumers (the static fast-path) expect [batch, draft_len]. Please document this contract here to avoid regressions, or enforce it with a wrapper that transposes if needed.
Example enforcement (if needed, not necessarily here):
def wrap_model(model): from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter - return ChainDrafter(spec_config.max_draft_len, model) + return ChainDrafter(spec_config.max_draft_len, model)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
tensorrt_llm/_torch/models/modeling_speculative.py
(0 hunks)tensorrt_llm/_torch/pyexecutor/model_engine.py
(9 hunks)tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
(2 hunks)tensorrt_llm/_torch/speculative/eagle3.py
(2 hunks)tensorrt_llm/_torch/speculative/model_drafter.py
(4 hunks)
💤 Files with no reviewable changes (1)
- tensorrt_llm/_torch/models/modeling_speculative.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespace when importing: from package.subpackage import foo; then use foo.SomeClass()
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; prefix k for names starting with a number (e.g., k_99th_percentile)
Global variables are UPPER_SNAKE_CASE prefixed with G (e.g., G_MY_GLOBAL)
Constants are UPPER_SNAKE_CASE
Avoid shadowing variables from an outer scope
Initialize all externally visible members of a class in init
For interfaces used outside a file, prefer docstrings over comments; comments for internal code or local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Attributes and variables can be documented inline with trailing docstrings under the class or module
Avoid using reflection when easily avoidable; prefer explicit parameters/constructs over dict(**locals())
In try/except, catch the narrowest exception types possible
For duck-typing try/except, keep try body minimal and place logic in else after attribute existence checks
Files:
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA Apache-2.0 copyright header with current year to all source files
Files:
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/speculative/eagle3.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
🧠 Learnings (2)
📚 Learning: 2025-08-25T11:33:08.068Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-25T11:33:08.068Z
Learning: Applies to **/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py} : Prepend NVIDIA Apache-2.0 copyright header with current year to all source files
Applied to files:
tensorrt_llm/_torch/speculative/eagle3.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/speculative/eagle3.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/speculative/model_drafter.py (3)
tensorrt_llm/runtime/generation.py (1)
max_draft_tokens
(1283-1286)tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
all_requests
(38-39)tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
append
(79-98)append
(125-142)
tensorrt_llm/_torch/speculative/eagle3.py (4)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
forward
(76-84)forward
(2054-2147)load_weights_from_target_model
(2245-2254)tensorrt_llm/_torch/models/modeling_speculative.py (5)
forward
(105-136)forward
(190-223)forward
(244-270)forward
(397-440)load_weights_from_target_model
(289-294)tensorrt_llm/_torch/speculative/mtp.py (2)
forward
(329-513)forward
(1123-1274)tensorrt_llm/_torch/attention_backend/interface.py (6)
num_seqs
(252-256)seq_lens_cuda
(222-223)on_update
(151-161)num_contexts
(202-203)num_contexts
(206-209)num_tokens
(274-275)
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)
tensorrt_llm/_torch/speculative/interface.py (3)
without_logits
(53-54)is_eagle3
(32-33)extend_ctx
(87-100)tensorrt_llm/llmapi/llm_args.py (3)
spec_dec_mode
(395-402)spec_dec_mode
(441-446)spec_dec_mode
(549-554)tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
enable_spec_decode
(52-53)tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py (1)
model
(162-165)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
277-279
: New model_wrapper hook in ctor: LGTM.The API extension is clear and typed; enables safe wrapping of the loaded model.
1016-1018
: Applying model_wrapper after load: LGTM.Wrapping post-initialization keeps loading logic unchanged and isolates the concern.
471-473
: runtime_draft_len excludes draft engines: LGTM.This avoids over-allocating buffers on the draft side.
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (2)
264-267
: Dummy-load forwarding for the draft engine: LGTM.Mirrors main-engine setting; keeps memory footprint predictable under tests.
283-284
: Passing model_wrapper into draft PyTorchModelEngine: LGTM.This cleanly isolates EAGLE3’s ChainDrafter to the draft engine only.
0874f2b to c24ef9e — Compare
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
1-1
: Add the required NVIDIA Apache-2.0 header (2025).Production Python files must prepend the standard NVIDIA Apache-2.0 header with the current year.
Apply this diff at the very top:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +
♻️ Duplicate comments (5)
tensorrt_llm/_torch/speculative/eagle3.py (5)
1-1
: Add the required NVIDIA Apache-2.0 header (2025).Production Python files must prepend the standard NVIDIA Apache-2.0 header with the current year.
Apply this diff at the very top of the file:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +
544-546
: Preserve original is_first_draft flags instead of forcing False/True. Blindly toggling these flags can break outer scheduling semantics and warmup/capture assumptions. Save and restore.

Apply:

```diff
-        spec_metadata.eagle3_resource_manager.is_first_draft = False
-        spec_metadata.is_first_draft = False
+        prev_rm_is_first_draft = spec_metadata.eagle3_resource_manager.is_first_draft
+        prev_spec_is_first_draft = spec_metadata.is_first_draft
+        spec_metadata.eagle3_resource_manager.is_first_draft = False
+        spec_metadata.is_first_draft = False
@@
-        spec_metadata.is_first_draft = True
-        spec_metadata.eagle3_resource_manager.is_first_draft = True
+        spec_metadata.is_first_draft = prev_spec_is_first_draft
+        spec_metadata.eagle3_resource_manager.is_first_draft = prev_rm_is_first_draft
```

Also applies to: 570-571
547-556
: Fix hidden-state index mapping (multi-seq correctness). Using old_write_indices[last_tokens_idx] and a 0..B-1 write range ignores per-slot start offsets and can clobber unrelated requests.
Apply:
- old_write_indices = spec_metadata.hidden_states_write_indices - - spec_metadata.hidden_states_read_indices[:batch_size].copy_( - old_write_indices[last_tokens_idx]) - spec_metadata.hidden_states_write_indices[:batch_size].copy_( - torch.arange( - batch_size, - dtype=spec_metadata.hidden_states_write_indices.dtype, - device=spec_metadata.hidden_states_write_indices.device)) + rm = spec_metadata.eagle3_resource_manager + req_ids = spec_metadata.request_ids + assert req_ids is not None and len(req_ids) == batch_size + start_offsets = torch.tensor( + [rm.start_indices[rm.slot_manager.get_slot(int(rid))] for rid in req_ids], + device=last_tokens_idx.device, + dtype=last_tokens_idx.dtype, + ) + rel_last = attn_metadata.seq_lens_cuda[:batch_size].to(start_offsets.dtype) - 1 + read_idx = start_offsets + rel_last + write_idx = start_offsets + rel_last + spec_metadata.hidden_states_read_indices[:batch_size].copy_(read_idx) + spec_metadata.hidden_states_write_indices[:batch_size].copy_(write_idx) spec_metadata.num_tokens = batch_size @@ - spec_metadata.hidden_states_read_indices[:batch_size].copy_( - spec_metadata.hidden_states_write_indices[:batch_size]) + spec_metadata.hidden_states_read_indices[:batch_size].copy_( + spec_metadata.hidden_states_write_indices[:batch_size])Also applies to: 567-569
578: Return shape bug: stack along dim=1 to get [batch, max_draft_len]. The current code returns [max_draft_len, batch]; consumers typically expect [batch, draft_len].
Apply:
```diff
-        return torch.stack(new_draft_tokens)
+        # [batch, max_draft_len]
+        return torch.stack(new_draft_tokens, dim=1)
```
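As a torch-free illustration of why the dim choice matters (helper names here are hypothetical; the real code uses torch.stack): each draft step yields one token per batch element, so stacking the per-step lists along the second axis is the transpose of stacking them along the first.

```python
# Pure-Python sketch of stack semantics: stacking along dim=0 gives
# [steps, batch], while stacking along dim=1 gives the desired [batch, steps].
def stack_dim0(steps):
    return [list(s) for s in steps]            # shape [steps, batch]

def stack_dim1(steps):
    return [list(col) for col in zip(*steps)]  # shape [batch, steps]

steps = [[1, 2], [3, 4], [5, 6]]  # 3 draft steps, batch of 2
print(stack_dim0(steps))  # [[1, 2], [3, 4], [5, 6]]
print(stack_dim1(steps))  # [[1, 3, 5], [2, 4, 6]]
```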
580-584: Avoid .data and make the d2t application robust; ensure integer dtype. Accessing .data is unsafe, d2t may be absent, and indices should be long. Return int32 for downstream consistency.
Apply:
```diff
-    def sample(self, logits: torch.Tensor) -> torch.Tensor:
-        tokens = torch.argmax(logits, dim=-1)
-        d2t = self.draft_model.model.d2t.data
-
-        return tokens + d2t[tokens]
+    def sample(self, logits: torch.Tensor) -> torch.Tensor:
+        tokens = torch.argmax(logits, dim=-1).to(torch.int32)
+        tokens_long = tokens.to(torch.long)
+        d2t = getattr(getattr(self.draft_model, "model", None), "d2t", None)
+        if d2t is not None:
+            return (tokens + d2t[tokens_long]).to(torch.int32)
+        return tokens
```
🧹 Nitpick comments (3)
tensorrt_llm/_torch/speculative/eagle3.py (1)
508-516: Add docstrings to the new public class and its methods. Document ChainDrafter's purpose, inputs/outputs, and state mutations (attn/spec metadata). This will help future maintainers and aligns with repo guidelines.
I can draft Google-style docstrings if helpful.
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
277-279: New model_wrapper parameter: clarify the contract. Consider documenting the expected behavior of the wrapper (e.g., it must return a torch.nn.Module exposing .config and .model_config, and a compatible forward). This helps future integrations.
910-912: Return type is too narrow after introducing wrappers. _load_model can now return a wrapper not derived from DecoderModelForCausalLM. Loosen the annotation to torch.nn.Module to reduce confusion for readers and type checkers.
Apply:
```diff
-def _load_model(...) -> DecoderModelForCausalLM:
+def _load_model(...) -> torch.nn.Module:
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
tensorrt_llm/_torch/models/modeling_speculative.py
(1 hunks)tensorrt_llm/_torch/pyexecutor/model_engine.py
(6 hunks)tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
(2 hunks)tensorrt_llm/_torch/speculative/eagle3.py
(1 hunks)tensorrt_llm/_torch/speculative/model_drafter.py
(4 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- tensorrt_llm/_torch/speculative/model_drafter.py
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespace when importing: from package.subpackage import foo; then use foo.SomeClass()
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; prefix k for names starting with a number (e.g., k_99th_percentile)
Global variables are UPPER_SNAKE_CASE prefixed with G (e.g., G_MY_GLOBAL)
Constants are UPPER_SNAKE_CASE
Avoid shadowing variables from an outer scope
Initialize all externally visible members of a class in init
For interfaces used outside a file, prefer docstrings over comments; comments for internal code or local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Attributes and variables can be documented inline with trailing docstrings under the class or module
Avoid using reflection when easily avoidable; prefer explicit parameters/constructs over dict(**locals())
In try/except, catch the narrowest exception types possible
For duck-typing try/except, keep try body minimal and place logic in else after attribute existence checks
Files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/speculative/eagle3.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA Apache-2.0 copyright header with current year to all source files
Files:
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/_torch/speculative/eagle3.py
🧠 Learnings (2)
📚 Learning: 2025-08-25T11:33:08.068Z
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-25T11:33:08.068Z
Learning: Applies to **/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py} : Prepend NVIDIA Apache-2.0 copyright header with current year to all source files
Applied to files:
tensorrt_llm/_torch/speculative/eagle3.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/speculative/eagle3.py
🧬 Code graph analysis (2)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
tensorrt_llm/_torch/speculative/interface.py (2)
without_logits
(53-54)is_eagle3
(32-33)tensorrt_llm/llmapi/llm_args.py (3)
spec_dec_mode
(395-402)spec_dec_mode
(441-446)spec_dec_mode
(549-554)
tensorrt_llm/_torch/speculative/eagle3.py (3)
tensorrt_llm/_torch/models/modeling_utils.py (1)
config
(500-501)tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
load_weights_from_target_model
(2245-2254)tensorrt_llm/_torch/attention_backend/interface.py (6)
num_seqs
(252-256)seq_lens_cuda
(222-223)on_update
(151-161)num_contexts
(202-203)num_contexts
(206-209)num_tokens
(274-275)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/speculative/eagle3.py (1)
517-526: Verify the forward contract for the draft engine when without_logits=True. ModelEngine._forward_step returns the raw result when without_logits=True. Ensure all callers of the draft path expect a Tensor of shape [B, draft_len] (after the dim fix) and not a dict.
Would you like me to scan the repo for usages and surface any callsites that assume a dict?
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)
12: Import addition looks good.
314-316: Passing model_wrapper through to _load_model looks correct.
407-411: Double-check the "without_logits" hack for Eagle3 draft engines. This makes the engine return the wrapper's raw output. Verify that all draft-path consumers handle non-dict outputs (Tensor) and that this doesn't accidentally silence logits where needed (e.g., non-EAGLE3 modes).
I can scan usages to confirm assumptions if you want.
1016-1018: Wrapper application point LGTM. Wrapping after weight load/finalize and stream sync minimizes surprises; attributes like .config/.model_config must be preserved by the wrapper (as done in ChainDrafter).
7e34691 to 0ec0734 (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (2)
tensorrt_llm/_torch/speculative/model_drafter.py (2)
268-270: Don't drop the d2t mapping; reattach it safely to preserve guided decoding invariants. Commenting this out breaks flows that need draft-to-target vocab remapping. Reattach under a guard and avoid .data.
Apply:
```diff
-# if hasattr(self.draft_model_engine.model.model, 'd2t'):
-#     outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
+d2t = getattr(getattr(self.draft_model_engine.model, "model", None), "d2t", None)
+if d2t is not None:
+    outputs['d2t'] = d2t  # tensor/parameter; do not use .data
```
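A tiny runnable sketch of the chained-getattr pattern suggested above (the class names here are hypothetical stand-ins for the engine and model objects):

```python
# Each getattr hop defaults to None, so a missing attribute anywhere in the
# chain short-circuits to None instead of raising AttributeError.
class Inner:
    d2t = [10, 20, 30]  # stand-in for the draft-to-target vocab mapping

class Outer:
    model = Inner()

def get_d2t(engine_model):
    return getattr(getattr(engine_model, "model", None), "d2t", None)

print(get_d2t(Outer()))   # [10, 20, 30]
print(get_d2t(object()))  # None
```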
389-397: The static path is functionally wrong and slow (CPU copy, wrong recipients, leaks resources). Issues:
- outputs.cpu() forces H2D/D2H sync, defeating perf goals.
- Appends tokens for both context and gen requests indiscriminately.
- Ignores per-request allocation limits.
- Leaks draft seq slot resources and misses padding/rollback.
Apply the GPU-friendly fix:
```diff
-        if self.use_static_draft_loop:
-            outputs_host = outputs.cpu()
-            for token_idx in range(self.max_draft_tokens):
-                for req_idx, req in enumerate(draft_batch.all_requests()):
-                    target_req = req_id_to_old_request[req.py_request_id]
-                    target_req.py_draft_tokens.append(
-                        outputs_host[token_idx][req_idx])
-
-            return
+        if self.use_static_draft_loop:
+            # Accept either Tensor or dict from wrapper
+            tokens = outputs
+            if isinstance(tokens, dict):
+                tokens = tokens.get('next_draft_tokens', None)
+            assert isinstance(tokens, torch.Tensor), \
+                "Static draft loop expects Tensor or {'next_draft_tokens': Tensor}"
+            # Normalize to [B, T]
+            if tokens.dim() == 2 and tokens.shape[0] == self.max_draft_tokens:
+                tokens_bt = tokens.transpose(0, 1).contiguous()  # [B, T]
+            else:
+                tokens_bt = tokens  # assume [B, T]
+
+            # Map only generation requests; contexts come first
+            ctx_count = len(draft_batch.context_requests)
+            gen_reqs = draft_batch.generation_requests
+            for i, req in enumerate(gen_reqs):
+                target_req = req_id_to_old_request[req.py_request_id]
+                # Respect allocated capacity
+                remaining = max(
+                    0, target_req.py_draft_pages_allocated - len(target_req.py_draft_tokens)
+                )
+                n = min(self.max_draft_tokens, remaining)
+                if n > 0:
+                    target_req.py_draft_tokens.extend(
+                        tokens_bt[ctx_count + i, :n].tolist()
+                    )
+
+            # Free per-request resources allocated for this draft pass
+            for req in draft_batch.all_requests():
+                self.draft_seq_slot_manager.free_resources(req)
+            # Pad and keep invariants for downstream CUDA graph paths
+            self._pad_to_max_draft_tokens(scheduled_requests)
+            if self.guided_decoder is not None:
+                self.guided_decoder.rollback_draft_tokens(scheduled_requests)
+            return
```
🧹 Nitpick comments (7)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (3)
259-276: ChainDrafter wrapper: define a clear contract and avoid late-binding hazards. The conditional wrapper wiring looks good and matches the gating (no guided decoding, no mixed sampler). Two actionable tweaks:
- Capture the max_draft_len by value to avoid relying on the outer spec_config reference at call time.
- Document/standardize the wrapper’s output contract (e.g., return a dict with next_draft_tokens: Tensor[B, T] and optionally d2t) so downstream ModelDrafter can rely on it.
Apply:
```diff
-    if use_chain_drafter:
-
-        def drafting_loop_wrapper(model):
-            from tensorrt_llm._torch.speculative.eagle3 import \
-                ChainDrafter
-
-            return ChainDrafter(spec_config.max_draft_len, model)
+    if use_chain_drafter:
+        max_draft_len = spec_config.max_draft_len  # capture by value
+
+        def drafting_loop_wrapper(model):
+            # Lazy import to keep scope tight
+            from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter
+            # Contract: forward(...) should return either:
+            #   - Tensor[T, B] or Tensor[B, T] draft tokens, or
+            #   - {'next_draft_tokens': Tensor[B, T], 'd2t': Optional[Tensor]}
+            return ChainDrafter(max_draft_len, model)
```

Also applies to: 292-293
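The "capture by value" tweak matters because Python closures bind names, not values. A minimal, self-contained demonstration (unrelated to any TensorRT-LLM API):

```python
# Late binding: every closure reads the loop variable at call time, after the
# loop has finished, so all of them see its final value.
def make_late():
    fns = []
    for n in [1, 2, 3]:
        fns.append(lambda: n)          # 'n' is looked up when called
    return [f() for f in fns]

def make_captured():
    fns = []
    for n in [1, 2, 3]:
        fns.append(lambda n=n: n)      # default arg freezes the value now
    return [f() for f in fns]

print(make_late())      # [3, 3, 3]
print(make_captured())  # [1, 2, 3]
```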
24-28: Duplicate/ambiguous import of is_mla shadows the earlier symbol. Both ._util and .config_utils export is_mla; importing from both is error-prone. Prefer a single source to avoid shadowing and confusion.
Apply one of the following; option A keeps config_utils variant:
```diff
-from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
-                    create_py_executor_instance, instantiate_sampler, is_mla)
+from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
+                    create_py_executor_instance, instantiate_sampler)
@@
-from .config_utils import is_mla
+from .config_utils import is_mla
```

Or option B keeps the _util variant:
```diff
-from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
-                    create_py_executor_instance, instantiate_sampler, is_mla)
+from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
+                    create_py_executor_instance, instantiate_sampler, is_mla)
@@
-from .config_utils import is_mla
+# from .config_utils import is_mla  # avoid duplicate symbol
```
1-7: Missing NVIDIA Apache-2.0 header. Per repo guidelines, prepend the NVIDIA Apache-2.0 header with the current year.
Apply:
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```

tensorrt_llm/_torch/speculative/model_drafter.py (2)
75-80: Static draft loop gating is OK; add a docstring and explicit opt-in rationale. Using draft_model_engine.model_is_wrapped to toggle the static path matches the wrapper wiring. Please add a short docstring for use_static_draft_loop describing the constraints (no guided decoder, no mixed sampler) and why the mode exists, to aid future maintainers.
Apply:
```diff
@@
-        self.use_static_draft_loop = draft_model_engine.model_is_wrapped
+        # When the draft model is wrapped by a "static drafter" (e.g., ChainDrafter),
+        # we switch to a fast-path that expects the wrapper to produce draft tokens
+        # without logits. This path currently forbids guided decoding and mixed sampler.
+        self.use_static_draft_loop = draft_model_engine.model_is_wrapped
         if self.use_static_draft_loop:
             # TODO: enable sampling/guided decoding on static draft loop
             assert guided_decoder is None
             assert not sampler.enable_mixed_sampler
```
1-8: Missing NVIDIA Apache-2.0 header. Per guidelines, prepend the header.
Apply:
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
277-279: New public parameter: add a docstring and usage notes. drafting_loop_wrapper is a public API change. Document its semantics (called post-load; must return a torch.nn.Module compatible with DecoderModelForCausalLM conventions or proxy the required attributes).
Apply:
```diff
 class PyTorchModelEngine(ModelEngine):
@@
     def __init__(
         self,
@@
-        drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
-                                                 torch.nn.Module]] = None,
+        drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None,
     ):
+        """
+        Args:
+            drafting_loop_wrapper: Optional callable applied to the loaded model to wrap/replace
+                the draft model's decoding loop (e.g., ChainDrafter). The wrapper must return a
+                torch.nn.Module that either:
+                - subclasses DecoderModelForCausalLM, or
+                - proxies attributes and methods used by PyTorchModelEngine (config, model,
+                  model_config, forward, load_weights_from_target_model, etc.).
+        """
```
1-15: Missing NVIDIA Apache-2.0 header. Per guidelines, prepend the header.
Apply:
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
tensorrt_llm/_torch/models/modeling_speculative.py
(1 hunks)tensorrt_llm/_torch/pyexecutor/model_engine.py
(6 hunks)tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
(2 hunks)tensorrt_llm/_torch/speculative/eagle3.py
(1 hunks)tensorrt_llm/_torch/speculative/model_drafter.py
(4 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/speculative/eagle3.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
Files:
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA Apache-2.0 copyright header with current year to all source files
Files:
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
🧬 Code graph analysis (2)
tensorrt_llm/_torch/speculative/model_drafter.py (2)
tensorrt_llm/runtime/generation.py (1)
max_draft_tokens
(1283-1286)tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
all_requests
(38-39)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
drafting_loop_wrapper
(265-269)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (4)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
273-276: Dummy load format handling for the draft engine is correct. The cloned backend config and conditional DUMMY load format for the draft engine preserve main-engine settings while respecting dummy spec loads. No concerns.
tensorrt_llm/_torch/speculative/model_drafter.py (1)
246-248: CUDA graph disable heuristic changed — verify interactions with draft wrappers. Always returning False when use_static_draft_loop means we'll attempt CUDA graphs even when KV recompute might be needed. Ensure ChainDrafter's forward does not require disabling CUDA graphs; otherwise capture/replay may be invalid.
You can sanity-check by running a warmup that captures graphs with a wrapped draft model and asserting no recaptures or runtime graph invalidation occur (look for warnings/errors about graph invalidation in logs).
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
407-409: without_logits widened for wrapped models — confirm downstream expectations. This forces the draft path to return non-logits outputs when wrapped. Ensure all consumers (e.g., ModelDrafter) handle dict vs tensor outputs robustly (your ModelDrafter static-path fix above does). Consider adding a small unit test to lock in this behavior.
Would you like me to draft a minimal test that wraps a tiny nn.Module with a fake ChainDrafter and validates the end-to-end shape contract?
1014-1019: Enforce the wrapper contract immediately after application. After you apply drafting_loop_wrapper, add a runtime check that the returned object is still an nn.Module and that it exposes every attribute the engine relies on (config, model_config, forward, load_weights_from_target_model). If any are missing, raise a clear error listing exactly which attributes weren't preserved. Suggested diff:
```diff
     if drafting_loop_wrapper is not None:
-        model = drafting_loop_wrapper(model)
-        self.model_is_wrapped = True
+        wrapped = drafting_loop_wrapper(model)
+        # === Validate that the wrapper preserved the required interface ===
+        required = ("config", "model_config", "forward",
+                    "load_weights_from_target_model")
+        missing = [attr for attr in required if not hasattr(wrapped, attr)]
+        if missing:
+            raise TypeError(
+                f"Drafting wrapper must return nn.Module with attributes "
+                f"{required!r}; missing: {missing}")
+        model = wrapped
+        self.model_is_wrapped = True
     else:
         self.model_is_wrapped = False
```

• Ensures the engine won't later hit obscure AttributeErrors when accessing self.model.config, self.model.forward, etc.
• Points developers to the exact wrapper implementation (e.g. ChainDrafter) so they can confirm it passes through or delegates each attribute correctly.
• Provides an immediate, user-friendly error if the contract is violated, speeding up debugging.
37fb26e to 4721451 (Compare)
Actionable comments posted: 4
♻️ Duplicate comments (2)
tensorrt_llm/_torch/speculative/model_drafter.py (2)
267-271: Fix the d2t handling - use safe tensor access patterns. The commented guidance from past reviews was correct - avoid using .data, which bypasses autograd safety.

```diff
-# Handle d2t data if available. Static drafting loops should incorporate d2t
-# in their implementations.
-if not self.use_static_draft_loop and hasattr(
-        self.draft_model_engine.model.model, 'd2t'):
-    outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
+# Handle d2t data if available. Static drafting loops should incorporate d2t
+# in their implementations.
+if not self.use_static_draft_loop:
+    d2t = getattr(getattr(self.draft_model_engine.model, "model", None), "d2t", None)
+    if d2t is not None and isinstance(d2t, torch.Tensor):
+        outputs['d2t'] = d2t.detach()  # Safe tensor access without .data
```
391-400: Critical: the static draft path bypasses essential processing steps. The current implementation has several issues:
- Moving outputs to CPU defeats performance goals
- Indiscriminately appends tokens for all requests (context and generation)
- Ignores per-request allocation limits
- Skips resource cleanup and padding
The static draft path needs proper handling of:
- Generation vs context requests distinction
- Per-request draft token allocation limits
- Resource cleanup
- Output padding for CUDA graphs
Consider implementing a proper fast path or removing this early return until the ChainDrafter can handle these requirements correctly.
🧹 Nitpick comments (5)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
259-272: Consider making the ChainDrafter opt-in instead of implicit. The current logic automatically enables ChainDrafter based on certain conditions. Consider making this an explicit configuration option for better control and clarity.
The automatic enabling based on
guided_decoding_config is None and not enable_mixed_sampler
may not be intuitive for users. Consider adding an explicit configuration flag likeuse_chain_drafter
toPyTorchConfig
that defaults toFalse
.tensorrt_llm/_torch/speculative/model_drafter.py (1)
75-80: Document the static draft loop limitations more clearly. The assertions enforce constraints, but the TODO comment doesn't fully explain why these limitations exist or when they'll be addressed.
Add a more detailed comment explaining:
- Why guided decoding is incompatible with the static draft loop
- Why mixed sampling is disabled
- Timeline or conditions for enabling these features
```diff
     self.use_static_draft_loop = draft_model_engine.model_is_wrapped
     if self.use_static_draft_loop:
-        # TODO: enable sampling/guided decoding on static draft loop
+        # Static draft loop currently doesn't support:
+        # 1. Guided decoding - requires per-token vocabulary masking
+        # 2. Mixed sampling - requires logits processing between iterations
+        # These will be enabled once the ChainDrafter supports these features
         assert guided_decoder is None
         assert not sampler.enable_mixed_sampler
```

tensorrt_llm/_torch/speculative/drafting_loops.py (3)
19-36: Consider error handling for metadata state restoration. The context manager correctly saves and restores state, but if an exception occurs during cloning or copying, the metadata could be left in an inconsistent state.
Consider adding error handling to ensure metadata consistency:

```diff
 @contextmanager
 def save_metadata_state(attn_metadata: AttentionMetadata,
                         spec_metadata: SpecMetadata) -> None:
     batch_size = attn_metadata.num_seqs
+    seq_len = None
+    seq_len_cuda = None
     if attn_metadata.is_cuda_graph:
-        seq_len = attn_metadata._seq_lens[:batch_size].clone()
-        seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+        try:
+            seq_len = attn_metadata._seq_lens[:batch_size].clone()
+            seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+        except Exception as e:
+            # Re-raise to prevent silent failures
+            raise RuntimeError(f"Failed to save metadata state: {e}") from e
     try:
         yield
     finally:
-        if attn_metadata.is_cuda_graph:
+        if attn_metadata.is_cuda_graph and seq_len is not None:
             attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
             attn_metadata._seq_lens_cuda[:batch_size].copy_(
                 seq_len_cuda[:batch_size])
         spec_metadata.reset()
```
@contextmanager def save_metadata_state(attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata) -> None: batch_size = attn_metadata.num_seqs + seq_len = None + seq_len_cuda = None if attn_metadata.is_cuda_graph: - seq_len = attn_metadata._seq_lens[:batch_size].clone() - seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone() + try: + seq_len = attn_metadata._seq_lens[:batch_size].clone() + seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone() + except Exception as e: + # Log error and re-raise to prevent silent failures + raise RuntimeError(f"Failed to save metadata state: {e}") from e try: yield finally: - if attn_metadata.is_cuda_graph: + if attn_metadata.is_cuda_graph and seq_len is not None: attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size]) attn_metadata._seq_lens_cuda[:batch_size].copy_( seq_len_cuda[:batch_size]) spec_metadata.reset()
47-48: Magic number 1 should be documented. The value 1 being filled into host_request_types appears to represent generation requests, but this isn't documented. Add a comment explaining the meaning:

```diff
+    # Set all requests to generation type (1)
     attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
```
67-75: Add validation for draft_model requirements. The ChainDrafter assumes the draft_model has certain attributes (config, model_config) but doesn't validate them.
```diff
     def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
         super().__init__()
+        if not hasattr(draft_model, 'config'):
+            raise ValueError("draft_model must have a 'config' attribute")
+        if not hasattr(draft_model, 'model_config'):
+            raise ValueError("draft_model must have a 'model_config' attribute")
         self.draft_model = draft_model
         self.config = self.draft_model.config
         self.model_config = self.draft_model.model_config
         self.max_draft_len = max_draft_len
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (7)
tensorrt_llm/_torch/models/modeling_speculative.py
(1 hunks)tensorrt_llm/_torch/pyexecutor/model_engine.py
(6 hunks)tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
(2 hunks)tensorrt_llm/_torch/speculative/drafting_loops.py
(1 hunks)tensorrt_llm/_torch/speculative/eagle3.py
(1 hunks)tensorrt_llm/_torch/speculative/interface.py
(1 hunks)tensorrt_llm/_torch/speculative/model_drafter.py
(4 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/speculative/eagle3.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
Files:
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/speculative/drafting_loops.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA Apache-2.0 copyright header with current year to all source files
Files:
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/speculative/drafting_loops.py
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/speculative/interface.py (1)
tensorrt_llm/_torch/speculative/eagle3.py (1)
reset
(188-190)
tensorrt_llm/_torch/speculative/model_drafter.py (2)
tensorrt_llm/runtime/generation.py (1)
max_draft_tokens
(1283-1286)tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
all_requests
(38-39)
tensorrt_llm/_torch/speculative/drafting_loops.py (3)
tensorrt_llm/_torch/attention_backend/interface.py (4)
AttentionMetadata
(39-328)num_seqs
(252-256)on_update
(151-161)seq_lens_cuda
(222-223)tensorrt_llm/_torch/speculative/interface.py (2)
SpecMetadata
(116-217)reset
(213-217)tensorrt_llm/_torch/speculative/eagle3.py (3)
Eagle3SpecMetadata
(81-190)reset
(188-190)forward
(285-396)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (4)
tensorrt_llm/_torch/speculative/interface.py (1)
213-217: LGTM! The addition of the reset()
hook provides a clean interface for resetting speculative metadata state after drafting loops, with a sensible no-op default implementation.tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
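The pattern can be sketched in isolation (class names here are hypothetical; the real SpecMetadata lives in tensorrt_llm/_torch/speculative/interface.py):

```python
# Base class ships a no-op reset() so callers can invoke it unconditionally;
# only stateful subclasses override it.
class SpecMetadataBase:
    def reset(self) -> None:
        """Reset per-drafting-loop state. No-op by default."""

class Eagle3LikeMetadata(SpecMetadataBase):
    def __init__(self):
        self.is_first_draft = True

    def reset(self) -> None:
        self.is_first_draft = True  # restore state mutated during drafting

meta = Eagle3LikeMetadata()
meta.is_first_draft = False  # mutated by a drafting loop
meta.reset()
print(meta.is_first_draft)  # True
```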
273-276: Preserve copy semantics for clarity. Creating a copy of the configuration before modifying it preserves the original state, improves code clarity, and prevents potential side effects.
tensorrt_llm/_torch/speculative/model_drafter.py (1)
246-248: LGTM - Correct CUDA graph handling for static draft loops. Skipping CUDA graph disabling for static draft loops is appropriate, since the wrapped model handles its own CUDA graph state management.
tensorrt_llm/_torch/speculative/drafting_loops.py (1)
105-108: Verify the conditional logic for hidden states indices. The hidden states indices are only updated when i == 0, which seems intentional but could use a comment explaining why this only happens on the first iteration. Add a comment explaining the logic:

```diff
 attn_metadata.kv_lens_cuda[:batch_size] += 1
 if i == 0:
+    # After the first draft token, all subsequent tokens read from the same indices
     spec_metadata.hidden_states_read_indices[:batch_size].copy_(
         spec_metadata.hidden_states_write_indices[:batch_size])
```
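The effect of that `i == 0` branch can be illustrated with a toy sketch in plain Python (hypothetical names and list-based "indices" stand in for the real CUDA tensors): the first drafting step reads hidden states from the slots the target model wrote, and every later step reads from the slots the draft model itself writes, so the read mapping only needs to change once.

```python
# Toy sketch of the read/write index switch-over in a drafting loop.
# read_idx / write_idx are stand-ins for hidden_states_read_indices /
# hidden_states_write_indices; num_iters plays the role of max_draft_len.
def run_draft_loop(read_idx, write_idx, num_iters):
    history = []
    for i in range(num_iters):
        history.append(list(read_idx))  # indices consumed on this step
        if i == 0:
            # After the first draft token, switch reads over to the draft
            # model's own write slots; no further updates are needed.
            read_idx = list(write_idx)
    return history

# Step 0 reads the target model's slots; steps 1+ all read the write slots.
steps = run_draft_loop(read_idx=[0, 1], write_idx=[4, 5], num_iters=3)
```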
Description
Introduce a new concept called capturable drafting loops (CDLs) to the 2-model speculation infrastructure. The goal of this approach is to get the performance of 1-model speculation while keeping the modularity and flexibility of 2-model. This is achieved by wrapping the drafting loop into the draft ModelEngine and launching it as a single CUDA graph.
This is very much like the SpecWorker abstraction in the 1-model implementation, except it does not include verification. The advantage compared to standard 1-model is that everything still goes through the Drafter abstraction: no more divergent code paths.
To get the best performance, we need to finish overlap scheduling and also support CUDA graph for the first eagle3 iteration (which may include recomputation of the KV cache for accepted tokens).
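The chain-style drafting loop described above — repeatedly calling the draft model inside one wrapped forward and greedily sampling each step — can be sketched in plain Python with a mock model (all names here are illustrative; the real ChainDrafter wraps a PyTorch draft model so the whole loop can be captured in a single CUDA graph):

```python
# Illustrative sketch of a chain drafting loop: feed each greedily sampled
# token back into the draft model until max_draft_len tokens are produced.
def chain_draft(draft_model, last_token, max_draft_len):
    tokens = []
    for _ in range(max_draft_len):
        logits = draft_model(last_token)
        # Greedy sampling: take the argmax of the logits at each step.
        last_token = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(last_token)
    return tokens

# Mock "draft model": the logits peak at (token + 1) mod 4, so the chain
# simply counts upward from the input token.
mock = lambda t: [1.0 if i == (t + 1) % 4 else 0.0 for i in range(4)]
```

In the real implementation the loop body runs tensor ops only (no host synchronization), which is what makes the whole chain capturable as one CUDA graph.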
Test Coverage
Covered by existing eagle3 2-model tests.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run
/bot [-h|--help]
to print this help message.See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill
Kill all running builds associated with the pull request.
skip
skip --comment COMMENT
Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit
New Features
Performance
API Changes
Compatibility