@@ -23,7 +23,7 @@
 from tensorrt_llm.mapping import Mapping

 from .finish_reason import FinishedState
-from .llm_request import LlmRequest, LlmRequestState
+from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
 from .scheduler import ScheduledRequests


@@ -337,7 +337,7 @@ def update_requests(self, state: SampleState) -> None:
             new_token = add_token(req, new_tokens, beam=self.BEAM)
             stop = self._handle_stop_criteria(req, new_token)
             processed = 1
-            if not stop and len(req.py_draft_tokens) > 0:
+            if not stop and get_draft_token_length(req) > 0:
                 num_accepted = self.process_draft_tokens(
                     req, new_tokens, new_token)
                 req.py_num_accepted_draft_tokens = num_accepted
@@ -401,7 +401,7 @@ def _process_requests(self,
         beam_width = self.MAX_BEAM_WIDTH
         beam = self.BEAM
         raw_logits = model_outputs["logits"]
-        num_steps = [1 + len(req.py_draft_tokens) for req in requests]
+        num_steps = [1 + get_draft_token_length(req) for req in requests]
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
         fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None