@@ -8,7 +8,8 @@
 from tensorrt_llm._utils import nvtx_range
 from tensorrt_llm.logger import logger
 
-from ..pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig
+from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
+                                      SamplingConfig, get_draft_token_length)
 from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
 from ..pyexecutor.sampler import Sampler, SampleState
 from ..pyexecutor.scheduler import ScheduledRequests
@@ -59,7 +60,6 @@ def __init__(
         # Configuration
         self.spec_config = spec_config
         self.max_draft_tokens = max_draft_tokens
-
         # Sampling
         self.sampler = sampler
 
@@ -214,7 +214,6 @@ def _prepare_draft_batch(
             if request.py_draft_pages_allocated == 0:
                 # No space for draft tokens
                 continue
-
             # Stop drafting when we hit the max seqlen. We still need dummy draft
             # tokens attached to the requests to make sure everything works properly
             # with CUDA graph. These dummy tokens are already added by
@@ -320,7 +319,7 @@ def _pad_to_max_draft_tokens(self,
         """Pad draft tokens to maximum length for all generation requests."""
         for req in scheduled_requests.generation_requests:
             max_draft_tokens = self.max_draft_tokens
-            num_draft_tokens = len(req.py_draft_tokens)
+            num_draft_tokens = get_draft_token_length(req)
             req.py_draft_tokens.extend(
                 0 for _ in range(max_draft_tokens - num_draft_tokens))
 