Skip to content

Commit 4721451

Browse files
committed
[feat] Improve 2-model perf
Signed-off-by: Mike Iovine <[email protected]>
1 parent 200db3b commit 4721451

File tree

7 files changed

+189
-21
lines changed

7 files changed

+189
-21
lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Generic, Optional, Tuple
1+
from typing import Dict, Generic, Optional, Tuple
22

33
import torch
44
from torch import nn
@@ -293,18 +293,6 @@ def load_weights_from_target_model(self,
293293
if self.load_lm_head_from_target:
294294
self.lm_head = target_model.lm_head
295295

296-
# TODO: should input/position IDs be included in this? Keeping it implicit
297-
# for now since the shapes/dtypes are the same across all models we have.
298-
def get_warmup_extra_inputs(self, batch_size: int,
299-
num_tokens: int) -> Dict[str, Any]:
300-
301-
hidden_states = torch.empty(batch_size * num_tokens,
302-
self.model.hidden_size,
303-
dtype=self.model.dtype,
304-
device='cuda')
305-
306-
return {'hidden_states': hidden_states}
307-
308296
def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
309297
"""
310298
Hack for eagle3. We might need to run a matmul to reduce

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import weakref
1010
from abc import ABC, abstractmethod
1111
from contextlib import contextmanager
12-
from typing import Any, Dict, Optional, Tuple
12+
from typing import Any, Callable, Dict, Optional, Tuple
1313

1414
import torch
1515
import torch._dynamo.config
@@ -274,6 +274,8 @@ def __init__(
274274
spec_config: Optional["DecodingBaseConfig"] = None,
275275
lora_config: Optional[LoraConfig] = None,
276276
is_draft_model: bool = False,
277+
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
278+
torch.nn.Module]] = None,
277279
):
278280
self.ub_buffers = None
279281
self.batch_size = batch_size
@@ -309,7 +311,8 @@ def __init__(
309311
max_num_tokens=max_num_tokens,
310312
moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
311313
moe_load_balancer=pytorch_backend_config.moe_load_balancer,
312-
lora_config=lora_config)
314+
lora_config=lora_config,
315+
drafting_loop_wrapper=drafting_loop_wrapper)
313316
# In case that some tests use stub models and override `_load_model`.
314317
if not hasattr(self.model, 'extra_attrs'):
315318
self.model.extra_attrs = {}
@@ -402,7 +405,7 @@ def __init__(
402405
dtype=torch.int,
403406
device='cuda')
404407
self.without_logits = self.spec_config.spec_dec_mode.without_logits(
405-
)
408+
) or self.model_is_wrapped
406409
self.max_draft_len = spec_config.max_draft_len
407410
else:
408411
self.without_logits = False
@@ -902,6 +905,8 @@ def _load_model(self,
902905
moe_max_num_tokens: Optional[int] = None,
903906
moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
904907
lora_config: Optional[LoraConfig] = None,
908+
drafting_loop_wrapper: Optional[Callable[
909+
[torch.nn.Module], torch.nn.Module]] = None,
905910
**kwargs) -> DecoderModelForCausalLM:
906911
config = checkpoint_loader.load_config(
907912
checkpoint_dir,
@@ -1005,6 +1010,13 @@ def init_meta_tensor(t: torch.Tensor):
10051010
logger.info("moe_load_balancer finalize model done")
10061011

10071012
torch.cuda.current_stream().synchronize()
1013+
1014+
if drafting_loop_wrapper is not None:
1015+
model = drafting_loop_wrapper(model)
1016+
self.model_is_wrapped = True
1017+
else:
1018+
self.model_is_wrapped = False
1019+
10081020
return model
10091021

10101022
def _call_load_weights(self, load_method, weights, weight_mapper):

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,13 +252,28 @@ def create_py_executor(
252252
with mem_monitor.observe_creation_stage(
253253
_ExecutorCreationStage.MODEL_ENGINE_DRAFT):
254254
draft_spec_config = copy.copy(spec_config)
255-
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
256-
if spec_config.load_format == "dummy":
257-
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
258255
# The draft model won't have any draft tokens attached to
259256
# generation requests when we invoke it autoregressively
260257
draft_spec_config.max_draft_len = 0
261258

259+
use_chain_drafter = (
260+
executor_config.guided_decoding_config is None
261+
and not pytorch_backend_config.enable_mixed_sampler)
262+
263+
if use_chain_drafter:
264+
265+
def drafting_loop_wrapper(model):
266+
from tensorrt_llm._torch.speculative.drafting_loops import \
267+
ChainDrafter
268+
269+
return ChainDrafter(spec_config.max_draft_len, model)
270+
else:
271+
drafting_loop_wrapper = None
272+
273+
draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
274+
if spec_config.load_format == "dummy":
275+
draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
276+
262277
draft_model_engine = PyTorchModelEngine(
263278
model_path=spec_config.speculative_model_dir,
264279
pytorch_backend_config=draft_pytorch_backend_config,
@@ -274,6 +289,7 @@ def create_py_executor(
274289
spec_config=draft_spec_config,
275290
checkpoint_loader=executor_config.checkpoint_loader,
276291
is_draft_model=True,
292+
drafting_loop_wrapper=drafting_loop_wrapper,
277293
)
278294
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
279295
draft_model_engine.load_weights_from_target_model(
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
This module contains capturable drafting loops for speculative decoding.
3+
4+
Thes are torch modules wrap another draft model. The wrapped module
5+
is supposed to invoke the draft model autoregressively and invoke
6+
a sampling algorithm to obtain draft tokens. By structuring the code
7+
like this, we are able to avoid host overhead: the entire drafting process
8+
for speculation can be launched as a single CUDA graph.
9+
"""
10+
11+
import torch
12+
from contextlib import contextmanager
13+
14+
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
15+
from tensorrt_llm._torch.speculative.interface import SpecMetadata
16+
from tensorrt_llm._torch.speculative.eagle3 import Eagle3SpecMetadata
17+
18+
19+
@contextmanager
20+
def save_metadata_state(attn_metadata: AttentionMetadata,
21+
spec_metadata: SpecMetadata) -> None:
22+
batch_size = attn_metadata.num_seqs
23+
if attn_metadata.is_cuda_graph:
24+
seq_len = attn_metadata._seq_lens[:batch_size].clone()
25+
seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
26+
27+
try:
28+
yield
29+
finally:
30+
if attn_metadata.is_cuda_graph:
31+
attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
32+
attn_metadata._seq_lens_cuda[:batch_size].copy_(
33+
seq_len_cuda[:batch_size])
34+
35+
spec_metadata.reset()
36+
37+
38+
def prepare_for_generation(attn_metadata: AttentionMetadata,
39+
spec_metadata: SpecMetadata,
40+
last_tokens_idx: torch.Tensor) -> None:
41+
batch_size = attn_metadata.num_seqs
42+
attn_metadata._seq_lens[:batch_size].fill_(1)
43+
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
44+
attn_metadata.on_update()
45+
attn_metadata.kv_lens_cuda[:batch_size] += 1
46+
47+
attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
48+
attn_metadata.num_contexts = 0
49+
50+
spec_metadata.num_tokens = batch_size
51+
52+
if isinstance(spec_metadata, Eagle3SpecMetadata):
53+
spec_metadata.eagle3_resource_manager.is_first_draft = False
54+
spec_metadata.is_first_draft = False
55+
56+
old_write_indices = spec_metadata.hidden_states_write_indices
57+
58+
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
59+
old_write_indices[last_tokens_idx])
60+
spec_metadata.hidden_states_write_indices[:batch_size].copy_(
61+
torch.arange(
62+
batch_size,
63+
dtype=spec_metadata.hidden_states_write_indices.dtype,
64+
device=spec_metadata.hidden_states_write_indices.device))
65+
66+
67+
class ChainDrafter(torch.nn.Module):
68+
69+
def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
70+
super().__init__()
71+
self.draft_model = draft_model
72+
self.config = self.draft_model.config
73+
self.model_config = self.draft_model.model_config
74+
self.max_draft_len = max_draft_len
75+
76+
def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
77+
attn_metadata: AttentionMetadata,
78+
spec_metadata: AttentionMetadata, **kwargs) -> None:
79+
80+
logits = self.draft_model.forward(input_ids=input_ids,
81+
position_ids=position_ids,
82+
attn_metadata=attn_metadata,
83+
spec_metadata=spec_metadata)
84+
85+
new_draft_tokens = [self.sample(logits)]
86+
87+
with save_metadata_state(attn_metadata, spec_metadata):
88+
batch_size = attn_metadata.num_seqs
89+
last_tokens_idx = torch.cumsum(
90+
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
91+
new_position_ids = position_ids[0, last_tokens_idx] + 1
92+
93+
prepare_for_generation(attn_metadata, spec_metadata,
94+
last_tokens_idx)
95+
96+
for i in range(self.max_draft_len - 1):
97+
logits = self.draft_model.forward(
98+
input_ids=new_draft_tokens[-1],
99+
position_ids=new_position_ids,
100+
attn_metadata=attn_metadata,
101+
spec_metadata=spec_metadata)
102+
new_draft_tokens.append(self.sample(logits))
103+
new_position_ids += 1
104+
attn_metadata.kv_lens_cuda[:batch_size] += 1
105+
if i == 0:
106+
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
107+
spec_metadata.hidden_states_write_indices[:batch_size])
108+
109+
return torch.stack(new_draft_tokens)
110+
111+
def sample(self, logits: torch.Tensor) -> torch.Tensor:
112+
tokens = torch.argmax(logits, dim=-1)
113+
if hasattr(self.draft_model.model, "d2t"):
114+
d2t = self.draft_model.model.d2t.data
115+
return tokens + d2t[tokens]
116+
117+
return tokens
118+
119+
def load_weights_from_target_model(self,
120+
target_model: torch.nn.Module) -> None:
121+
self.draft_model.load_weights_from_target_model(target_model)

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ def get_hidden_states(self):
185185
hidden_states = hidden_states[:, :self.hidden_size]
186186
return hidden_states
187187

188+
def reset(self):
189+
self.is_first_draft = True
190+
self.eagle3_resource_manager.is_first_draft = True
191+
188192

189193
@dataclass
190194
class Eagle3OneModelSpecMetadata(SpecMetadata):

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,9 @@ def all_rank_num_tokens(self, value: Optional[List[int]]):
209209
value = value if value is not SpecMetadata.all_rank_num_tokens else None
210210
self._all_rank_num_tokens = value
211211
self.all_rank_max_num_tokens = max(value) if value is not None else None
212+
213+
def reset(self):
214+
"""
215+
Currently used by 2-model static drafting loops only. Used to reset any spec metadata
216+
to its original state after the drafting loop exists. Does nothing by default.
217+
"""

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ def __init__(
7272
self._request_draft_logits = sampler.enable_mixed_sampler
7373
self.guided_decoder = guided_decoder
7474

75+
self.use_static_draft_loop = draft_model_engine.model_is_wrapped
76+
if self.use_static_draft_loop:
77+
# TODO: enable sampling/guided decoding on static draft loop
78+
assert guided_decoder is None
79+
assert not sampler.enable_mixed_sampler
80+
7581
def _create_draft_request(self, request: LlmRequest,
7682
input_tokens: Optional[List]) -> LlmRequest:
7783
"""Create a draft request with common parameters."""
@@ -237,6 +243,8 @@ def _should_disable_cuda_graph(
237243
"""Check if CUDA graph should be disabled for the current forward pass."""
238244
if previous_batch is not None:
239245
return False
246+
if self.use_static_draft_loop:
247+
return False
240248
return self.spec_config.spec_dec_mode.needs_kv_cache_recompute()
241249

242250
def _forward_draft_model(
@@ -256,8 +264,10 @@ def _forward_draft_model(
256264
resource_manager,
257265
new_tensors_device=new_tensors_device)
258266

259-
# Handle d2t data if available
260-
if hasattr(self.draft_model_engine.model.model, 'd2t'):
267+
# Handle d2t data if available. Static drafting loops should incorporate d2t
268+
# in their implementations.
269+
if not self.use_static_draft_loop and hasattr(
270+
self.draft_model_engine.model.model, 'd2t'):
261271
outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
262272

263273
return outputs
@@ -377,6 +387,17 @@ def prepare_draft_tokens(
377387

378388
# Initial forward pass
379389
outputs = self._forward_draft_model(draft_batch, resource_manager)
390+
391+
if self.use_static_draft_loop:
392+
outputs_host = outputs.cpu()
393+
for token_idx in range(self.max_draft_tokens):
394+
for req_idx, req in enumerate(draft_batch.all_requests()):
395+
target_req = req_id_to_old_request[req.py_request_id]
396+
target_req.py_draft_tokens.append(
397+
outputs_host[token_idx][req_idx])
398+
399+
return
400+
380401
self._execute_guided_decoder(draft_batch,
381402
outputs['logits'],
382403
d2t=outputs.get('d2t'))

0 commit comments

Comments
 (0)