Skip to content

Commit 0874f2b

Browse files
committed
[feat] Improve 2-model perf
Signed-off-by: Mike Iovine <[email protected]>
1 parent 200db3b commit 0874f2b

File tree

5 files changed

+129
-25
lines changed

5 files changed

+129
-25
lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -293,18 +293,6 @@ def load_weights_from_target_model(self,
293293
if self.load_lm_head_from_target:
294294
self.lm_head = target_model.lm_head
295295

296-
# TODO: should input/position IDs be included in this? Keeping it implicit
297-
# for now since the shapes/dtypes are the same across all models we have.
298-
def get_warmup_extra_inputs(self, batch_size: int,
299-
num_tokens: int) -> Dict[str, Any]:
300-
301-
hidden_states = torch.empty(batch_size * num_tokens,
302-
self.model.hidden_size,
303-
dtype=self.model.dtype,
304-
device='cuda')
305-
306-
return {'hidden_states': hidden_states}
307-
308296
def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
309297
"""
310298
Hack for eagle3. We might need to run a matmul to reduce

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import weakref
1010
from abc import ABC, abstractmethod
1111
from contextlib import contextmanager
12-
from typing import Any, Dict, Optional, Tuple
12+
from typing import Any, Dict, Optional, Tuple, Callable
1313

1414
import torch
1515
import torch._dynamo.config
@@ -274,6 +274,8 @@ def __init__(
274274
spec_config: Optional["DecodingBaseConfig"] = None,
275275
lora_config: Optional[LoraConfig] = None,
276276
is_draft_model: bool = False,
277+
model_wrapper: Optional[Callable[[torch.nn.Module],
278+
torch.nn.Module]] = None,
277279
):
278280
self.ub_buffers = None
279281
self.batch_size = batch_size
@@ -309,7 +311,8 @@ def __init__(
309311
max_num_tokens=max_num_tokens,
310312
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
311313
moe_load_balancer=pytorch_backend_config.moe_load_balancer,
312-
lora_config=lora_config)
314+
lora_config=lora_config,
315+
model_wrapper=model_wrapper)
313316
# In case that some tests use stub models and override `_load_model`.
314317
if not hasattr(self.model, 'extra_attrs'):
315318
self.model.extra_attrs = {}
@@ -387,7 +390,7 @@ def __init__(
387390
self.spec_metadata = None
388391
update_spec_config_from_model_config(self.spec_config,
389392
self.model.config)
390-
max_num_draft_tokens = self.spec_config.max_draft_len * batch_size
393+
max_num_draft_tokens = self.spec_config.max_draft_len * batch_size if not self.is_draft_model else 0
391394
self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
392395
dtype=torch.int,
393396
device='cuda')
@@ -401,9 +404,11 @@ def __init__(
401404
self.previous_kv_lens_offsets_cuda = torch.zeros((batch_size, ),
402405
dtype=torch.int,
403406
device='cuda')
407+
# TODO: undo this hack (forces without_logits for EAGLE3 draft-model engines)
404408
self.without_logits = self.spec_config.spec_dec_mode.without_logits(
405-
)
406-
self.max_draft_len = spec_config.max_draft_len
409+
) or (self.is_draft_model
410+
and self.spec_config.spec_dec_mode.is_eagle3())
411+
self.max_draft_len = spec_config.max_draft_len if not self.is_draft_model else 0
407412
else:
408413
self.without_logits = False
409414
self.max_draft_len = 0
@@ -464,7 +469,7 @@ def __init__(
464469

465470
@property
466471
def runtime_draft_len(self):
467-
return self.max_draft_len if self.enable_spec_decode else 0
472+
return self.max_draft_len if self.enable_spec_decode and not self.is_draft_model else 0
468473

469474
def set_lora_model_config(self,
470475
lora_target_modules: list[str],
@@ -902,6 +907,8 @@ def _load_model(self,
902907
moe_max_num_tokens: Optional[int] = None,
903908
moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
904909
lora_config: Optional[LoraConfig] = None,
910+
model_wrapper: Optional[Callable[[torch.nn.Module],
911+
torch.nn.Module]] = None,
905912
**kwargs) -> DecoderModelForCausalLM:
906913
config = checkpoint_loader.load_config(
907914
checkpoint_dir,
@@ -1005,6 +1012,10 @@ def init_meta_tensor(t: torch.Tensor):
10051012
logger.info("moe_load_balancer finalize model done")
10061013

10071014
torch.cuda.current_stream().synchronize()
1015+
1016+
if model_wrapper is not None:
1017+
model = model_wrapper(model)
1018+
10081019
return model
10091020

10101021
def _call_load_weights(self, load_method, weights, weight_mapper):
@@ -1257,7 +1268,7 @@ def _prepare_tp_inputs(
12571268
past_seen_token_num = request.max_beam_num_tokens - 1
12581269
draft_lens.append(num_draft_tokens)
12591270

1260-
if self.enable_spec_decode and spec_config.spec_dec_mode.extend_ctx(
1271+
if self.enable_spec_decode and not self.is_draft_model and spec_config.spec_dec_mode.extend_ctx(
12611272
self.attn_backend):
12621273
# We're treating the prompt lengths as context requests here, so
12631274
# the prompt lens should not include the cached tokens.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,13 +252,19 @@ def create_py_executor(
252252
with mem_monitor.observe_creation_stage(
253253
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
254254
draft_spec_config = copy.copy(spec_config)
255-
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
256-
if spec_config.load_format == "dummy":
257-
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
258255
# The draft model won't have any draft tokens attached to
259256
# generation requests when we invoke it autoregressively
260257
draft_spec_config.max_draft_len = 0
261258

259+
def wrap_model(model):
260+
from tensorrt_llm._torch.speculative.eagle3 import ChainDrafter
261+
262+
return ChainDrafter(spec_config.max_draft_len, model)
263+
264+
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
265+
if spec_config.load_format == "dummy":
266+
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
267+
262268
draft_model_engine = PyTorchModelEngine(
263269
model_path=spec_config.speculative_model_dir,
264270
pytorch_backend_config=draft_pytorch_backend_config,
@@ -274,6 +280,7 @@ def create_py_executor(
274280
spec_config=draft_spec_config,
275281
checkpoint_loader=executor_config.checkpoint_loader,
276282
is_draft_model=True,
283+
model_wrapper=wrap_model,
277284
)
278285
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
279286
draft_model_engine.load_weights_from_target_model(

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Optional, Set
2+
from typing import List, Optional, Set, Tuple, Any, Dict
33

44
import torch
55
from torch import nn
@@ -503,3 +503,86 @@ def prepare_1st_drafter_inputs(
503503
"attn_metadata": attn_metadata,
504504
"spec_metadata": spec_metadata,
505505
}
506+
507+
508+
class ChainDrafter(torch.nn.Module):
    """Wrap a draft model so a single ``forward`` call emits a chain of tokens.

    The wrapped model is run once on the incoming batch and then up to
    ``max_draft_len - 1`` more times, each time feeding back the token sampled
    in the previous step. Sampling is greedy (see :meth:`sample`).
    """

    def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
        """Store the wrapped model and the number of draft steps to run.

        Args:
            max_draft_len: Total number of draft steps per ``forward`` call
                (the first pass plus ``max_draft_len - 1`` follow-up passes).
            draft_model: The model to wrap. It must expose ``config`` and
                ``model_config``, which are mirrored on this wrapper.
        """
        super().__init__()
        self.draft_model = draft_model
        # Mirror the wrapped model's config objects so code that reads them
        # from the wrapper keeps working.
        self.config = self.draft_model.config
        self.model_config = self.draft_model.model_config
        self.max_draft_len = max_draft_len

    def forward(self, input_ids, position_ids, attn_metadata, spec_metadata,
                **kwargs) -> torch.Tensor:
        """Run the whole draft loop and return the sampled tokens.

        ``attn_metadata`` and ``spec_metadata`` are mutated in place between
        steps. Only the ``is_first_draft`` flags and, when
        ``attn_metadata.is_cuda_graph`` is set, the sequence-length buffers are
        put back before returning; other fields (e.g. ``num_contexts``,
        ``kv_lens_cuda``) are left as modified.

        Args:
            input_ids: Token ids for the first draft pass.
            position_ids: Position ids for the first draft pass; indexed as
                ``position_ids[0, idx]``, so it must have at least two
                dimensions.
            attn_metadata: Attention metadata object (project type).
            spec_metadata: Speculative-decoding metadata object (project type).
            **kwargs: Accepted for interface compatibility; not forwarded to
                the wrapped model.

        Returns:
            The per-step sampled tokens stacked along a new leading dimension
            (one entry for the first pass plus one per follow-up pass).
        """
        batch_size = attn_metadata.num_seqs

        # First pass over the batch exactly as it was handed to us.
        logits = self.draft_model.forward(input_ids=input_ids,
                                          position_ids=position_ids,
                                          attn_metadata=attn_metadata,
                                          spec_metadata=spec_metadata)

        new_draft_tokens = [self.sample(logits)]

        if attn_metadata.is_cuda_graph:
            # Snapshot the sequence lengths; they are overwritten below and
            # copied back at the end of this call.
            seq_len = attn_metadata._seq_lens[:batch_size].clone()
            seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()

        # Flat index of the last token of every sequence, computed from the
        # original sequence lengths before they are overwritten.
        last_tokens_idx = torch.cumsum(
            attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
        new_position_ids = position_ids[0, last_tokens_idx] + 1

        # From here on every sequence contributes exactly one token per step.
        attn_metadata._seq_lens[:batch_size].fill_(1)
        attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
        attn_metadata.on_update()
        attn_metadata.kv_lens_cuda[:batch_size] += 1

        # NOTE(review): presumably 1 marks a generation request, turning all
        # context requests into generation requests - confirm against the
        # attention backend.
        attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
        attn_metadata.num_contexts = 0

        spec_metadata.eagle3_resource_manager.is_first_draft = False
        spec_metadata.is_first_draft = False

        old_write_indices = spec_metadata.hidden_states_write_indices

        # Read the hidden state written for each sequence's last token, and
        # write the follow-up steps' hidden states to slots 0..batch_size-1.
        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
            old_write_indices[last_tokens_idx])
        spec_metadata.hidden_states_write_indices[:batch_size].copy_(
            torch.arange(
                batch_size,
                dtype=spec_metadata.hidden_states_write_indices.dtype,
                device=spec_metadata.hidden_states_write_indices.device))
        spec_metadata.num_tokens = batch_size

        for i in range(self.max_draft_len - 1):
            logits = self.draft_model.forward(input_ids=new_draft_tokens[-1],
                                              position_ids=new_position_ids,
                                              attn_metadata=attn_metadata,
                                              spec_metadata=spec_metadata)
            new_draft_tokens.append(self.sample(logits))
            new_position_ids += 1
            attn_metadata.kv_lens_cuda[:batch_size] += 1
            if i == 0:
                # After the first follow-up step, read from the same slots
                # that the follow-up steps write to.
                spec_metadata.hidden_states_read_indices[:batch_size].copy_(
                    spec_metadata.hidden_states_write_indices[:batch_size])

        spec_metadata.is_first_draft = True
        spec_metadata.eagle3_resource_manager.is_first_draft = True

        if attn_metadata.is_cuda_graph:
            # Put back the sequence lengths snapshotted above.
            attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
            attn_metadata._seq_lens_cuda[:batch_size].copy_(
                seq_len_cuda[:batch_size])

        return torch.stack(new_draft_tokens)

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Greedily pick a token per row and apply the ``d2t`` offsets if any.

        Args:
            logits: Scores whose last dimension is the vocabulary.

        Returns:
            ``argmax`` over the last dimension. When the wrapped model's inner
            ``model`` has a ``d2t`` table, ``d2t[token]`` is added to each
            token id; otherwise the argmax ids are returned unchanged.
        """
        tokens = torch.argmax(logits, dim=-1)
        # Not every draft model has a d2t table; fall back to the raw argmax
        # ids instead of raising AttributeError.
        d2t = getattr(self.draft_model.model, 'd2t', None)
        if d2t is None:
            return tokens

        return tokens + d2t.data[tokens]

    def load_weights_from_target_model(self,
                                       target_model: torch.nn.Module) -> None:
        """Delegate to the wrapped draft model's method of the same name."""
        self.draft_model.load_weights_from_target_model(target_model)

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def __init__(
7272
self._request_draft_logits = sampler.enable_mixed_sampler
7373
self.guided_decoder = guided_decoder
7474

75+
self.use_static_draft_loop = True
76+
7577
def _create_draft_request(self, request: LlmRequest,
7678
input_tokens: Optional[List]) -> LlmRequest:
7779
"""Create a draft request with common parameters."""
@@ -237,6 +239,8 @@ def _should_disable_cuda_graph(
237239
"""Check if CUDA graph should be disabled for the current forward pass."""
238240
if previous_batch is not None:
239241
return False
242+
if self.use_static_draft_loop:
243+
return False
240244
return self.spec_config.spec_dec_mode.needs_kv_cache_recompute()
241245

242246
def _forward_draft_model(
@@ -257,8 +261,8 @@ def _forward_draft_model(
257261
new_tensors_device=new_tensors_device)
258262

259263
# Handle d2t data if available
260-
if hasattr(self.draft_model_engine.model.model, 'd2t'):
261-
outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
264+
# if hasattr(self.draft_model_engine.model.model, 'd2t'):
265+
# outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
262266

263267
return outputs
264268

@@ -377,6 +381,17 @@ def prepare_draft_tokens(
377381

378382
# Initial forward pass
379383
outputs = self._forward_draft_model(draft_batch, resource_manager)
384+
385+
if self.use_static_draft_loop:
386+
outputs_host = outputs.cpu()
387+
for token_idx in range(self.max_draft_tokens):
388+
for req_idx, req in enumerate(draft_batch.all_requests()):
389+
target_req = req_id_to_old_request[req.py_request_id]
390+
target_req.py_draft_tokens.append(
391+
outputs_host[token_idx][req_idx])
392+
393+
return
394+
380395
self._execute_guided_decoder(draft_batch,
381396
outputs['logits'],
382397
d2t=outputs.get('d2t'))

0 commit comments

Comments
 (0)