Commit cdc0a2a

[feat] Improve 2-model perf
Signed-off-by: Mike Iovine <[email protected]>
1 parent 907bc22 commit cdc0a2a

File tree

5 files changed: +113 additions, -11 deletions


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 8 additions & 3 deletions
@@ -388,7 +388,7 @@ def __init__(
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.spec_config.max_draft_len * batch_size
+            max_num_draft_tokens = self.spec_config.max_draft_len * batch_size if not self.is_draft_model else 0
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
@@ -402,9 +402,11 @@
             self.previous_kv_lens_offsets_cuda = torch.zeros((batch_size, ),
                                                              dtype=torch.int,
                                                              device='cuda')
+            # TODO undo this hack
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
-            )
-            self.max_draft_len = spec_config.max_draft_len
+            ) or (self.is_draft_model
+                  and self.spec_config.spec_dec_mode.is_eagle3())
+            self.max_draft_len = spec_config.max_draft_len if not self.is_draft_model else 0
         else:
             self.without_logits = False
             self.max_draft_len = 0
@@ -1153,6 +1155,9 @@ def init_meta_tensor(t: torch.Tensor):
             logger.info("moe_load_balancer finalize model done")
 
         torch.cuda.current_stream().synchronize()
+        if self.spec_config is not None and self.is_draft_model:
+            model = self.spec_config.get_draft_model_wrapper(model) or model
+
         return model
 
     def _call_load_weights(self, load_method, weights, weight_mapper):
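
Note on this file's changes: a draft-model engine now sizes its draft-token buffer to zero, reports max_draft_len = 0, and, for Eagle3, runs without returning logits, because the wrapper installed in _load_model (the ChainDrafter below) samples tokens itself. A minimal sketch of that buffer/flag logic; SpecDecMode and draft_engine_flags here are simplified stand-ins, not the real TRT-LLM classes:

    # Hedged sketch of the draft-engine buffer/flag logic above.
    from dataclasses import dataclass

    import torch

    @dataclass
    class SpecDecMode:
        eagle3: bool

        def without_logits(self) -> bool:
            return False  # two-model modes normally return logits

        def is_eagle3(self) -> bool:
            return self.eagle3

    def draft_engine_flags(max_draft_len: int, batch_size: int,
                           is_draft_model: bool, mode: SpecDecMode):
        # Draft engines never receive draft tokens, so their buffer is empty.
        num_draft = 0 if is_draft_model else max_draft_len * batch_size
        draft_tokens = torch.empty((num_draft, ), dtype=torch.int)
        # Eagle3 draft engines hand back sampled tokens instead of logits.
        without_logits = mode.without_logits() or (is_draft_model
                                                   and mode.is_eagle3())
        return draft_tokens, without_logits

    buf, no_logits = draft_engine_flags(4, 8, True, SpecDecMode(eagle3=True))
    print(buf.numel(), no_logits)  # 0 True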

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 3 deletions
@@ -252,9 +252,6 @@ def create_py_executor(
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.MODEL_ENGINE_DRAFT):
             draft_spec_config = copy.copy(spec_config)
-            # The draft model won't have any draft tokens attached to
-            # generation requests when we invoke it autoregressively
-            draft_spec_config.max_draft_len = 0
 
             draft_model_engine = PyTorchModelEngine(
                 model_path=spec_config.speculative_model_dir,

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 82 additions & 3 deletions
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Any, Dict
 
 import torch
 from torch import nn
@@ -176,8 +176,6 @@ def maybe_capture_hidden_states(
     def get_hidden_states(self):
         hidden_states = self.eagle3_resource_manager.hidden_states[
             self.hidden_states_read_indices[:self.num_tokens], :]
-        if not self.is_first_draft:
-            hidden_states = hidden_states[:, :self.hidden_size]
         return hidden_states
 
 
@@ -493,3 +491,84 @@ def prepare_1st_drafter_inputs(
         "attn_metadata": attn_metadata,
         "spec_metadata": spec_metadata,
     }
+
+
+class ChainDrafter(torch.nn.Module):
+
+    def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
+        super().__init__()
+        self.draft_model = draft_model
+        self.config = self.draft_model.config
+        self.model_config = self.draft_model.model_config
+        self.max_draft_len = max_draft_len
+
+    def forward(self, input_ids, position_ids, attn_metadata, spec_metadata,
+                **kwargs):
+        batch_size = attn_metadata.num_seqs
+
+        logits = self.draft_model.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata,
+                                          spec_metadata=spec_metadata)
+
+        new_draft_tokens = [self.sample(logits)]
+
+        if attn_metadata.is_cuda_graph:
+            seq_len = attn_metadata._seq_lens[:batch_size].clone()
+            seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+
+        old_seqlens = torch.cumsum(attn_metadata.seq_lens_cuda,
+                                   dim=0,
+                                   dtype=torch.long)
+
+        last_tokens_idx = torch.cumsum(
+            attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
+        new_position_ids = position_ids[0, last_tokens_idx] + 1
+
+        attn_metadata._seq_lens[:batch_size].fill_(1)
+        attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+        attn_metadata.on_update()
+
+        attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
+        attn_metadata.num_contexts = 0
+
+        spec_metadata.eagle3_resource_manager.is_first_draft = False
+        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+            old_seqlens - 1)
+        spec_metadata.hidden_states_write_indices[:batch_size].copy_(
+            torch.arange(
+                batch_size,
+                dtype=spec_metadata.hidden_states_write_indices.dtype,
+                device=spec_metadata.hidden_states_write_indices.device))
+        spec_metadata.num_tokens = batch_size
+
+        for i in range(self.max_draft_len - 1):
+            logits = self.draft_model.forward(input_ids=new_draft_tokens[-1],
+                                              position_ids=new_position_ids,
+                                              attn_metadata=attn_metadata,
+                                              spec_metadata=spec_metadata)
+            new_draft_tokens.append(self.sample(logits))
+            new_position_ids += 1
+
+        spec_metadata.eagle3_resource_manager.is_first_draft = True
+
+        if attn_metadata.is_cuda_graph:
+            attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
+            attn_metadata._seq_lens_cuda[:batch_size].copy_(
+                seq_len_cuda[:batch_size])
+
+        return torch.stack(new_draft_tokens)
+
+    def sample(self, logits: torch.Tensor) -> torch.Tensor:
+        tokens = torch.argmax(logits, dim=-1)
+        d2t = self.draft_model.model.d2t.data
+
+        return tokens + d2t[tokens]
+
+    def get_warmup_extra_inputs(self, batch_size: int,
+                                num_tokens: int) -> Dict[str, Any]:
+        return self.draft_model.get_warmup_extra_inputs(batch_size, num_tokens)
+
+    def load_weights_from_target_model(self,
+                                       target_model: torch.nn.Module) -> None:
+        self.draft_model.load_weights_from_target_model(target_model)
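
ChainDrafter folds the whole drafting loop into a single forward(): the first pass consumes the accepted tokens, the following max_draft_len - 1 passes each feed back the token just sampled, and the results are stacked into a [max_draft_len, batch_size] tensor. sample() maps the draft-vocabulary argmax into the target vocabulary via the d2t offset table (token + d2t[token]). A toy illustration of that control flow, stripped of TRT-LLM's attention and spec metadata; ToyDraftModel is a hypothetical stand-in:

    # Toy sketch of the chained drafting loop; no KV cache or metadata.
    import torch

    class ToyDraftModel(torch.nn.Module):

        def __init__(self, vocab_size: int = 32, hidden: int = 16):
            super().__init__()
            self.emb = torch.nn.Embedding(vocab_size, hidden)
            self.head = torch.nn.Linear(hidden, vocab_size)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            # One logit vector per sequence, standing in for last-token logits.
            return self.head(self.emb(input_ids))

    def chain_draft(model: torch.nn.Module, last_tokens: torch.Tensor,
                    max_draft_len: int) -> torch.Tensor:
        # First step consumes the accepted tokens; each later step feeds
        # back the token it just drafted, max_draft_len steps in total.
        tokens = [torch.argmax(model(last_tokens), dim=-1)]
        for _ in range(max_draft_len - 1):
            tokens.append(torch.argmax(model(tokens[-1]), dim=-1))
        return torch.stack(tokens)  # [max_draft_len, batch_size]

    draft = chain_draft(ToyDraftModel(), torch.tensor([3, 7]), max_draft_len=4)
    print(draft.shape)  # torch.Size([4, 2])

Keeping the loop inside one module means a single forward() covers all draft steps, which appears to be why the seq-len state is mutated in place and restored before returning on the CUDA-graph path.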

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 15 additions & 2 deletions
@@ -72,6 +72,8 @@ def __init__(
         self._request_draft_logits = sampler.enable_mixed_sampler
         self.guided_decoder = guided_decoder
 
+        self.use_static_draft_loop = True
+
     def _create_draft_request(self, request: LlmRequest,
                               input_tokens: Optional[List]) -> LlmRequest:
         """Create a draft request with common parameters."""
@@ -257,8 +259,8 @@ def _forward_draft_model(
             new_tensors_device=new_tensors_device)
 
         # Handle d2t data if available
-        if hasattr(self.draft_model_engine.model.model, 'd2t'):
-            outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
+        # if hasattr(self.draft_model_engine.model.model, 'd2t'):
+        #     outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
 
         return outputs
 
@@ -377,6 +379,17 @@ def prepare_draft_tokens(
 
         # Initial forward pass
         outputs = self._forward_draft_model(draft_batch, resource_manager)
+
+        if self.use_static_draft_loop:
+            outputs_host = outputs.cpu()
+            for token_idx in range(self.max_draft_tokens):
+                for req_idx, req in enumerate(draft_batch.all_requests()):
+                    target_req = req_id_to_old_request[req.py_request_id]
+                    target_req.py_draft_tokens.append(
+                        outputs_host[token_idx][req_idx])
+
+            return
+
         self._execute_guided_decoder(draft_batch,
                                      outputs['logits'],
                                      d2t=outputs.get('d2t'))
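
With use_static_draft_loop enabled, the draft engine's output is the stacked [max_draft_len, batch_size] token tensor returned by ChainDrafter rather than a dict of logits, so prepare_draft_tokens only has to copy it to host once and scatter the tokens into each target request. A small sketch of that scatter with made-up values:

    # Sketch of the static-loop scatter; draft_out stands in for the
    # [max_draft_len, batch_size] tensor ChainDrafter returns.
    import torch

    draft_out = torch.tensor([[11, 21],
                              [12, 22],
                              [13, 23]])  # 3 draft tokens for 2 requests

    py_draft_tokens = {0: [], 1: []}  # per-target-request token lists
    outputs_host = draft_out.cpu()
    for token_idx in range(outputs_host.shape[0]):
        for req_idx in range(outputs_host.shape[1]):
            py_draft_tokens[req_idx].append(int(outputs_host[token_idx,
                                                             req_idx]))

    print(py_draft_tokens)  # {0: [11, 12, 13], 1: [21, 22, 23]}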

tensorrt_llm/llmapi/llm_args.py

Lines changed: 8 additions & 0 deletions
@@ -400,6 +400,9 @@ def spec_dec_mode(self):
         return TorchSpeculativeDecodingMode.from_string(
             self.decoding_type.upper())
 
+    def get_draft_model_wrapper(self, model):
+        return None
+
 
 class MedusaDecodingConfig(DecodingBaseConfig):
     medusa_choices: Optional[List[List[int]]] = None
@@ -443,6 +446,11 @@ def spec_dec_mode(self):
             return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
         return TorchSpeculativeDecodingMode.EAGLE3
 
+    def get_draft_model_wrapper(self, model):
+        from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter
+
+        return ChainDrafter(self.max_draft_len, model)
+
 
 class UserProvidedDecodingConfig(DecodingBaseConfig):
     # Cannot use real type annotations due to circular imports
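
The hook added here is deliberately polymorphic: the base config returns None so every other decoding mode keeps its draft model unwrapped, the Eagle3 config wraps it in a ChainDrafter, and model_engine.py composes the two with model = spec_config.get_draft_model_wrapper(model) or model. A stand-in sketch of that dispatch; BaseConfig and Eagle3Like are hypothetical names, not the real classes:

    # Stand-in sketch of the wrapper-hook dispatch; the real classes live in
    # tensorrt_llm/llmapi/llm_args.py and tensorrt_llm/_torch/speculative/eagle3.py.
    import torch

    class BaseConfig:

        def get_draft_model_wrapper(self, model: torch.nn.Module):
            return None  # most decoding modes leave the draft model as-is

    class Eagle3Like(BaseConfig):

        def __init__(self, max_draft_len: int, wrapper_cls):
            self.max_draft_len = max_draft_len
            self.wrapper_cls = wrapper_cls  # e.g. ChainDrafter

        def get_draft_model_wrapper(self, model: torch.nn.Module):
            # Wrap so that one forward() emits the whole draft chain.
            return self.wrapper_cls(self.max_draft_len, model)

    # Engine-side composition, mirroring the model_engine.py hunk:
    #   model = spec_config.get_draft_model_wrapper(model) or model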
