
Commit 1a42cf9

[feat] Improve 2-model perf
Signed-off-by: Mike Iovine <[email protected]>
Parent: 907bc22

6 files changed, 120 insertions(+), 24 deletions(-)

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 0 additions & 12 deletions
@@ -290,18 +290,6 @@ def load_weights_from_target_model(self,
         if self.load_lm_head_from_target:
             self.lm_head = target_model.lm_head
 
-    # TODO: should input/position IDs be included in this? Keeping it implicit
-    # for now since the shapes/dtypes are the same across all models we have.
-    def get_warmup_extra_inputs(self, batch_size: int,
-                                num_tokens: int) -> Dict[str, Any]:
-
-        hidden_states = torch.empty(batch_size * num_tokens,
-                                    self.model.hidden_size,
-                                    dtype=self.model.dtype,
-                                    device='cuda')
-
-        return {'hidden_states': hidden_states}
-
     def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """
         Hack for eagle3. We might need to run a matmul to reduce

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 13 additions & 6 deletions
@@ -388,7 +388,7 @@ def __init__(
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.spec_config.max_draft_len * batch_size
+            max_num_draft_tokens = self.spec_config.max_draft_len * batch_size if not self.is_draft_model else 0
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
@@ -402,9 +402,11 @@ def __init__(
             self.previous_kv_lens_offsets_cuda = torch.zeros((batch_size, ),
                                                              dtype=torch.int,
                                                              device='cuda')
+            # TODO undo this hack
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
-            )
-            self.max_draft_len = spec_config.max_draft_len
+            ) or (self.is_draft_model
+                  and self.spec_config.spec_dec_mode.is_eagle3())
+            self.max_draft_len = spec_config.max_draft_len if not self.is_draft_model else 0
         else:
             self.without_logits = False
             self.max_draft_len = 0
@@ -466,7 +468,7 @@ def __init__(
 
     @property
     def runtime_draft_len(self):
-        return self.max_draft_len if self.enable_spec_decode else 0
+        return self.max_draft_len if self.enable_spec_decode and not self.is_draft_model else 0
 
     def set_lora_model_config(self,
                               lora_target_modules: list[str],
@@ -989,7 +991,7 @@ def _maybe_get_cuda_graph(
         if ExpertStatistic.set_iter(self.iter_counter):
             return None
 
-        draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+        draft_len = self.spec_config.max_draft_len if self.enable_spec_decode and not self.is_draft_model else 0
         can_run_cuda_graph = batch.can_run_cuda_graph
         batch_size = len(batch.generation_requests)
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
@@ -1022,6 +1024,8 @@ def _maybe_get_cuda_graph(
         if self.enable_spec_decode:
             spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
                 num_sequences_in_batch)
+            if self.is_draft_model:
+                spec_metadata.max_draft_len = 0
             spec_metadata.draft_tokens = self.draft_tokens_cuda
         else:
             spec_metadata = None
@@ -1153,6 +1157,9 @@ def init_meta_tensor(t: torch.Tensor):
             logger.info("moe_load_balancer finalize model done")
 
         torch.cuda.current_stream().synchronize()
+        if self.spec_config is not None and self.is_draft_model:
+            model = self.spec_config.get_draft_model_wrapper(model) or model
+
         return model
 
     def _call_load_weights(self, load_method, weights, weight_mapper):
@@ -1411,7 +1418,7 @@ def _prepare_tp_inputs(
                 past_seen_token_num = request.max_beam_num_tokens - 1
             draft_lens.append(num_draft_tokens)
 
-            if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
+            if self.enable_spec_decode and not self.is_draft_model and spec_config.spec_dec_mode.extend_ctx(
                     self.attn_backend):
                 # We're treating the prompt lengths as context requests here, so
                 # the the prompt lens should not include the cached tokens.
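
The `get_draft_model_wrapper(model) or model` line added to the model-load path above relies on the base decoding config returning None (see the llm_args.py change below), so draft models for other decoding modes are left unwrapped. Below is a minimal, self-contained sketch of that fallback idiom; the class and function names are illustrative stand-ins, not the real TensorRT-LLM types.

# Hedged sketch of the `get_draft_model_wrapper(model) or model` fallback.
# Illustrative names only, not the TensorRT-LLM API.

class NoWrapperConfig:
    def get_draft_model_wrapper(self, model):
        return None  # base behavior: this decoding mode does not wrap the draft model


class WrappingConfig:
    def get_draft_model_wrapper(self, model):
        # Stands in for `ChainDrafter(self.max_draft_len, model)`.
        return ("wrapped", model)


def finalize_draft_model(spec_config, model, is_draft_model):
    if spec_config is not None and is_draft_model:
        # `or model` keeps the unwrapped model when the config returns None.
        model = spec_config.get_draft_model_wrapper(model) or model
    return model


assert finalize_draft_model(NoWrapperConfig(), "draft_model", True) == "draft_model"
assert finalize_draft_model(WrappingConfig(), "draft_model", True) == ("wrapped", "draft_model")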

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 3 deletions
@@ -252,9 +252,6 @@ def create_py_executor(
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.MODEL_ENGINE_DRAFT):
             draft_spec_config = copy.copy(spec_config)
-            # The draft model won't have any draft tokens attached to
-            # generation requests when we invoke it autoregressively
-            draft_spec_config.max_draft_len = 0
 
             draft_model_engine = PyTorchModelEngine(
                 model_path=spec_config.speculative_model_dir,

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 84 additions & 1 deletion
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Any, Dict
 
 import torch
 from torch import nn
@@ -493,3 +493,86 @@ def prepare_1st_drafter_inputs(
         "attn_metadata": attn_metadata,
         "spec_metadata": spec_metadata,
     }
+
+
+class ChainDrafter(torch.nn.Module):
+
+    def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
+        super().__init__()
+        self.draft_model = draft_model
+        self.config = self.draft_model.config
+        self.model_config = self.draft_model.model_config
+        self.max_draft_len = max_draft_len
+
+    def forward(self, input_ids, position_ids, attn_metadata, spec_metadata,
+                **kwargs):
+        batch_size = attn_metadata.num_seqs
+
+        logits = self.draft_model.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata,
+                                          spec_metadata=spec_metadata)
+
+        new_draft_tokens = [self.sample(logits)]
+
+        if attn_metadata.is_cuda_graph:
+            seq_len = attn_metadata._seq_lens[:batch_size].clone()
+            seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+
+        last_tokens_idx = torch.cumsum(
+            attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
+        new_position_ids = position_ids[0, last_tokens_idx] + 1
+
+        attn_metadata._seq_lens[:batch_size].fill_(1)
+        attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+        attn_metadata.on_update()
+        attn_metadata.kv_lens_cuda[:batch_size] += 1
+
+        attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
+        attn_metadata.num_contexts = 0
+
+        spec_metadata.eagle3_resource_manager.is_first_draft = False
+        spec_metadata.is_first_draft = False
+
+        old_write_indices = spec_metadata.hidden_states_write_indices
+
+        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+            old_write_indices[last_tokens_idx])
+        spec_metadata.hidden_states_write_indices[:batch_size].copy_(
+            torch.arange(
+                batch_size,
+                dtype=spec_metadata.hidden_states_write_indices.dtype,
+                device=spec_metadata.hidden_states_write_indices.device))
+        spec_metadata.num_tokens = batch_size
+
+        for i in range(self.max_draft_len - 1):
+            logits = self.draft_model.forward(input_ids=new_draft_tokens[-1],
+                                              position_ids=new_position_ids,
+                                              attn_metadata=attn_metadata,
+                                              spec_metadata=spec_metadata)
+            new_draft_tokens.append(self.sample(logits))
+            new_position_ids += 1
+            attn_metadata.kv_lens_cuda[:batch_size] += 1
+            if i == 0:
+                spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+                    spec_metadata.hidden_states_write_indices[:batch_size])
+
+        spec_metadata.is_first_draft = True
+        spec_metadata.eagle3_resource_manager.is_first_draft = True
+
+        if attn_metadata.is_cuda_graph:
+            attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
+            attn_metadata._seq_lens_cuda[:batch_size].copy_(
+                seq_len_cuda[:batch_size])
+
+        return torch.stack(new_draft_tokens)
+
+    def sample(self, logits: torch.Tensor) -> torch.Tensor:
+        tokens = torch.argmax(logits, dim=-1)
+        d2t = self.draft_model.model.d2t.data
+
+        return tokens + d2t[tokens]
+
+    def load_weights_from_target_model(self,
+                                       target_model: torch.nn.Module) -> None:
+        self.draft_model.load_weights_from_target_model(target_model)
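
ChainDrafter runs the whole multi-token draft loop inside a single forward call: it greedily samples a token, feeds it back as the next input, and finally stacks the results into a [max_draft_len, num_tokens] tensor, translating draft-vocabulary ids to target-vocabulary ids via the d2t table. Below is a minimal sketch of that chained-drafting pattern with a toy model, omitting the attention and spec-metadata bookkeeping shown above; all names in the sketch are illustrative.

# Toy illustration of the chained greedy drafting loop; not the real Eagle3
# draft model or metadata handling.
import torch


class ToyDraftModel(torch.nn.Module):

    def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.head = torch.nn.Linear(hidden_size, vocab_size)
        # d2t holds per-token offsets from draft-vocab ids to target-vocab ids;
        # all zeros here, i.e. an identity mapping.
        self.register_buffer("d2t", torch.zeros(vocab_size, dtype=torch.long))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(input_ids))


def chain_draft(model: ToyDraftModel, last_accepted: torch.Tensor,
                max_draft_len: int) -> torch.Tensor:
    """Greedily draft max_draft_len tokens per sequence, feeding each drafted
    token back in as the next input. Returns [max_draft_len, batch_size]."""

    def sample(logits: torch.Tensor) -> torch.Tensor:
        tokens = torch.argmax(logits, dim=-1)
        return tokens + model.d2t[tokens]  # draft-vocab -> target-vocab ids

    draft_tokens = [sample(model(last_accepted))]
    for _ in range(max_draft_len - 1):
        draft_tokens.append(sample(model(draft_tokens[-1])))
    return torch.stack(draft_tokens)


tokens = chain_draft(ToyDraftModel(), torch.tensor([3, 7]), max_draft_len=4)
print(tokens.shape)  # torch.Size([4, 2])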

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 15 additions & 2 deletions
@@ -72,6 +72,8 @@ def __init__(
         self._request_draft_logits = sampler.enable_mixed_sampler
         self.guided_decoder = guided_decoder
 
+        self.use_static_draft_loop = True
+
     def _create_draft_request(self, request: LlmRequest,
                               input_tokens: Optional[List]) -> LlmRequest:
         """Create a draft request with common parameters."""
@@ -257,8 +259,8 @@ def _forward_draft_model(
             new_tensors_device=new_tensors_device)
 
         # Handle d2t data if available
-        if hasattr(self.draft_model_engine.model.model, 'd2t'):
-            outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
+        # if hasattr(self.draft_model_engine.model.model, 'd2t'):
+        #     outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
 
         return outputs
 
@@ -377,6 +379,17 @@ def prepare_draft_tokens(
 
             # Initial forward pass
             outputs = self._forward_draft_model(draft_batch, resource_manager)
+
+            if self.use_static_draft_loop:
+                outputs_host = outputs.cpu()
+                for token_idx in range(self.max_draft_tokens):
+                    for req_idx, req in enumerate(draft_batch.all_requests()):
+                        target_req = req_id_to_old_request[req.py_request_id]
+                        target_req.py_draft_tokens.append(
+                            outputs_host[token_idx][req_idx])
+
+                return
+
             self._execute_guided_decoder(draft_batch,
                                          outputs['logits'],
                                          d2t=outputs.get('d2t'))
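
With use_static_draft_loop, the draft engine's output is consumed directly as token ids rather than logits: the code above assumes the wrapped drafter returns a [max_draft_tokens, batch_size] tensor, copies it to the host once, and appends one token per request per draft step. Below is a small self-contained sketch of that scatter step using stand-in request objects; the names are illustrative, not the real LlmRequest API.

# Hedged sketch of scattering a [max_draft_tokens, batch_size] tensor of draft
# token ids back onto the target requests; FakeRequest stands in for LlmRequest.
import torch


class FakeRequest:

    def __init__(self, request_id: int):
        self.py_request_id = request_id
        self.py_draft_tokens = []


def scatter_draft_tokens(outputs: torch.Tensor, draft_requests,
                         req_id_to_old_request) -> None:
    outputs_host = outputs.cpu()  # single device-to-host copy
    max_draft_tokens = outputs_host.shape[0]
    for token_idx in range(max_draft_tokens):
        for req_idx, req in enumerate(draft_requests):
            target_req = req_id_to_old_request[req.py_request_id]
            target_req.py_draft_tokens.append(outputs_host[token_idx][req_idx])


targets = {0: FakeRequest(0), 1: FakeRequest(1)}
drafts = [FakeRequest(0), FakeRequest(1)]
scatter_draft_tokens(torch.arange(8).reshape(4, 2), drafts, targets)
print(targets[0].py_draft_tokens)  # [tensor(0), tensor(2), tensor(4), tensor(6)]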

tensorrt_llm/llmapi/llm_args.py

Lines changed: 8 additions & 0 deletions
@@ -400,6 +400,9 @@ def spec_dec_mode(self):
         return TorchSpeculativeDecodingMode.from_string(
             self.decoding_type.upper())
 
+    def get_draft_model_wrapper(self, model):
+        return None
+
 
 class MedusaDecodingConfig(DecodingBaseConfig):
     medusa_choices: Optional[List[List[int]]] = None
@@ -443,6 +446,11 @@ def spec_dec_mode(self):
             return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
         return TorchSpeculativeDecodingMode.EAGLE3
 
+    def get_draft_model_wrapper(self, model):
+        from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter
+
+        return ChainDrafter(self.max_draft_len, model)
+
 
 class UserProvidedDecodingConfig(DecodingBaseConfig):
     # Cannot use real type annotations due to circular imports
