
Commit cfa2d42

Merge branch 'main' into andrewch-broken-links
Accept feedback from coderabbitai

Signed-off-by: Andrew Chen <[email protected]>
2 parents ee7a02a + 8062e0f commit cfa2d42

File tree

13 files changed: +223 -82 lines changed

docs/source/advanced/speculative-decoding.md

Lines changed: 3 additions & 2 deletions
@@ -60,7 +60,8 @@ These tokens are then forwarded to the Target model for verification.
 Upon verification, the Target model may return up to `K+1` tokens.
 Subsequently, the prompt, now updated with the accepted tokens, is sent back to the Draft model to initiate the generation of new draft tokens.
 This iterative process continues until a predefined stop conditions are met.
-An example of this orchestration process can be found in the [TensorRT-LLM Triton backend](https://github.com/triton-inference-server/tensorrtllm_backend).
+An example orchestration script is available in the Triton backend repository’s
+[draft-target-model client example](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/client/python/draft_target_model_client.py).

 We provide two styles of running Draft-Target-Model now: using TensorRT-LLM-BLS in Triton Inference Server, or using TensorRT-LLM directly. Detailed steps of running can be found in [examples/draft_target_model/README.md](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/draft_target_model/README.md) and the code can be found in [examples/ngram/run_dtm_ngram.py](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/ngram/run_dtm_ngram.py).

@@ -172,7 +173,7 @@ Similarly to ReDrafter, TensorRT-LLM implements the EAGLE model such that logits

 ### Disaggregated Serving

-[Disaggregated Serving](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/disaggregated-service.md) with EAGLE3 using the two model approach is supported in the Pytorch backend.
+[Disaggregated Serving](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/disaggregated-service.md) with EAGLE-3 using the two-model approach is supported in the PyTorch backend.

 ## Lookahead Decoding

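For context, the Draft-Target-Model flow described in the first hunk above can be sketched in a few lines. This is a minimal illustration only, not the linked `draft_target_model_client.py`; `draft_llm`, `target_llm`, `k`, and `stop_condition` are hypothetical placeholders.

```python
# Minimal sketch of the Draft-Target-Model loop, assuming hypothetical
# draft_llm/target_llm objects; the real orchestration lives in the linked
# draft_target_model_client.py example.
def speculative_generate(prompt_ids, draft_llm, target_llm, k, stop_condition):
    tokens = list(prompt_ids)
    while not stop_condition(tokens):
        # 1. The Draft model proposes up to K candidate tokens.
        draft_tokens = draft_llm.generate(tokens, max_new_tokens=k)
        # 2. The Target model verifies them and may return up to K+1 tokens
        #    (the accepted draft tokens plus one token of its own).
        accepted = target_llm.verify(tokens, draft_tokens)
        tokens.extend(accepted)
        # 3. The updated sequence is sent back to the Draft model.
    return tokens
```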

docs/source/blogs/Falcon180B-H200.md

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ Often quantization can have adverse impacts on the accuracy of the model,
 however, TensorRT-LLM's AWQ decreases memory footprint of the model by **4x**
 while maintaining high accuracy.

-<img src="https://github.com/NVIDIA/TensorRT-LLM/blob/5aec7af45fc0abd876fa68a9ae8c8cae084f3af3/docs/source/blogs/media/Falcon180B-H200_acc.png" alt="Falcon-180B accuracy comparison" width="600" height="auto">
+<img src="https://github.com/NVIDIA/TensorRT-LLM/blob/5aec7af45fc0abd876fa68a9ae8c8cae084f3af3/docs/source/blogs/media/Falcon180B-H200_acc.png?raw=true" alt="Falcon-180B accuracy comparison" width="600" height="auto">


 <sup>Preliminary measured accuracy, subject to change. </sup>

docs/source/release-notes.md

Lines changed: 1 addition & 1 deletion
@@ -1045,7 +1045,7 @@ Refer to the {ref}`support-matrix-software` section for a list of supported mode
 - System prompt caching
 - Enabled split-k for weight-only cutlass kernels
 - FP8 KV cache support for XQA kernel
-- New Python builder API and `trtllm-build` command and OPT
+- Added Python builder API, `trtllm-build` command, and OPT support
 - Support `StoppingCriteria` and `LogitsProcessor` in Python generate API
 - FHMA support for chunked attention and paged KV cache
 - Performance enhancements include:

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 14 additions & 0 deletions
@@ -477,3 +477,17 @@ def executor_request_to_llm_request(
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None))
     return llm_request
+
+
+def get_draft_token_length(request: LlmRequest) -> int:
+    """Get the length of draft tokens for a given request.
+
+    Args:
+        request: The LlmRequest to get draft token length for
+
+    Returns:
+        The number of draft tokens, or 0 if no draft tokens exist
+    """
+    if request.py_draft_tokens is not None:
+        return len(request.py_draft_tokens)
+    return 0
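The new helper centralizes the `None` check on `py_draft_tokens`, so call sites no longer need to guard against requests that carry no draft tokens; the remaining files in this commit switch to it. A typical call pattern, mirroring the `sampler.py` hunk further down:

```python
# len(req.py_draft_tokens) raises TypeError when py_draft_tokens is None;
# the helper simply reports 0 draft tokens in that case.
num_steps = [1 + get_draft_token_length(req) for req in requests]
```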

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 89 additions & 63 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 5 additions & 2 deletions
@@ -861,6 +861,9 @@ def _prepare_and_schedule_batch(self):
         self._pad_attention_dp_dummy_request()

         if self.drafter is not None:
+            self.use_spec_decode = self.drafter.should_use_spec_decode(
+                self.active_requests)
+            self.model_engine.enable_spec_decode = self.use_spec_decode
             self._prepare_draft_requests(self.active_requests)

         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
@@ -922,7 +925,7 @@ def _executor_loop(self):
                     self._handle_first_token_response(scheduled_batch)

                     self.resource_manager.prepare_resources(scheduled_batch)
-                    if self.drafter is not None:
+                    if self.drafter is not None and self.use_spec_decode:
                         self.drafter.prepare_draft_tokens(
                             scheduled_batch, self.resource_manager)

@@ -973,7 +976,7 @@ def _prepare_draft_requests(self, requests):
                 req.py_last_draft_tokens = req.py_draft_tokens
                 max_draft_len = self.model_engine.spec_config.max_draft_len

-                if max_draft_len > 0:
+                if max_draft_len > 0 and self.use_spec_decode:
                     req.py_draft_tokens = [0] * max_draft_len
                     req.py_draft_pages_allocated = max_draft_len
                 else:

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 4 additions & 3 deletions
@@ -15,7 +15,8 @@
 from ..._utils import binding_dtype_size, nvtx_range
 from ...logger import logger
 from ...mapping import Mapping
-from .llm_request import LlmRequest, LlmRequestState, SamplingConfig
+from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
+                          get_draft_token_length)
 from .scheduler import ScheduledRequests

 if ENABLE_MULTI_DEVICE:
@@ -368,12 +369,12 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                                       req_beam_width, req)
             for _ in range(self.num_extra_kv_tokens):
                 self.impl.add_token(req.py_request_id)
-            for _ in range(len(req.py_draft_tokens)):
+            for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)

         for req in generation_batch:
             self.impl.add_token(req.py_request_id)
-            for _ in range(len(req.py_draft_tokens)):
+            for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)

     def add_dummy_requests(

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 from tensorrt_llm.mapping import Mapping

 from .finish_reason import FinishedState
-from .llm_request import LlmRequest, LlmRequestState
+from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
 from .scheduler import ScheduledRequests


@@ -337,7 +337,7 @@ def update_requests(self, state: SampleState) -> None:
             new_token = add_token(req, new_tokens, beam=self.BEAM)
             stop = self._handle_stop_criteria(req, new_token)
             processed = 1
-            if not stop and len(req.py_draft_tokens) > 0:
+            if not stop and get_draft_token_length(req) > 0:
                 num_accepted = self.process_draft_tokens(
                     req, new_tokens, new_token)
                 req.py_num_accepted_draft_tokens = num_accepted
@@ -401,7 +401,7 @@ def _process_requests(self,
         beam_width = self.MAX_BEAM_WIDTH
         beam = self.BEAM
         raw_logits = model_outputs["logits"]
-        num_steps = [1 + len(req.py_draft_tokens) for req in requests]
+        num_steps = [1 + get_draft_token_length(req) for req in requests]
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
         fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None

tensorrt_llm/_torch/pyexecutor/scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 from tensorrt_llm.bindings import executor as tb_executor
 from tensorrt_llm.bindings import internal as tb_internal

-from .llm_request import LlmRequest, LlmRequestState
+from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length

 RequestList = list[LlmRequest]

@@ -185,7 +185,7 @@ def schedule(
         self, active_requests: RequestList, inflight_request_ids: set[int]
     ) -> tuple[list[LlmRequest], list[LlmRequest]]:
         for request in active_requests:
-            if len(request.py_draft_tokens) > 0:
+            if get_draft_token_length(request) > 0:
                 request.draft_tokens = request.py_draft_tokens
         return self.impl(active_requests, inflight_request_ids,
                          self.max_batch_size, self.max_num_tokens)

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional

+from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import ResourceManager
 from ..pyexecutor.scheduler import ScheduledRequests

@@ -21,3 +22,7 @@ def prepare_draft_tokens(
             scheduled_requests: The scheduled requests for this iteration
         """
         raise NotImplementedError
+
+    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
+        """Check if spec decode should be used for the current iteration."""
+        return True
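The base implementation opts in unconditionally, which preserves the previous behavior, while `py_executor.py` consults this hook once per scheduling iteration. Below is a hypothetical override; the batch-size threshold is purely illustrative, and the `prepare_draft_tokens` signature is inferred from its call site in `py_executor.py`.

```python
from typing import List, Optional

# Import paths follow the files touched in this commit.
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm._torch.speculative.drafter import Drafter


class BatchAwareDrafter(Drafter):
    """Hypothetical drafter that skips speculation for large batches."""

    def __init__(self, max_concurrency: int = 32):
        self.max_concurrency = max_concurrency  # illustrative threshold

    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
        # Speculate only when the batch is small enough for the extra
        # draft/verify work to pay off.
        return len(requests) <= self.max_concurrency

    def prepare_draft_tokens(
            self,
            scheduled_requests: ScheduledRequests,
            resource_manager: Optional[ResourceManager] = None) -> None:
        # Populate request.py_draft_tokens for the scheduled requests here.
        ...
```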
